feat: add Xinference model integration #959

Merged: 1 commit, Aug 12, 2024
@@ -21,6 +21,7 @@
     VolcanicEngineModelProvider
 from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider
 from setting.models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider
+from setting.models_provider.impl.xinference_model_provider.xinference_model_provider import XinferenceModelProvider
 from setting.models_provider.impl.zhipu_model_provider.zhipu_model_provider import ZhiPuModelProvider
 from setting.models_provider.impl.local_model_provider.local_model_provider import LocalModelProvider

@@ -40,3 +41,4 @@ class ModelProvideConstants(Enum):
     model_tencent_provider = TencentModelProvider()
     model_aws_bedrock_provider = BedrockModelProvider()
     model_local_provider = LocalModelProvider()
+    model_xinference_provider = XinferenceModelProvider()
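
The enum doubles as a provider registry: the provider key stored with a model record is used elsewhere in MaxKB to look the instance back up. A minimal lookup sketch, assuming only the enum above (the helper name is illustrative, not part of this PR):

    def get_provider(provider_key: str):
        # e.g. provider_key == 'model_xinference_provider' after this change
        return ModelProvideConstants[provider_key].value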
@@ -31,13 +31,56 @@ def _get_aws_bedrock_icon_path():
 
 
 def _initialize_model_info():
-    model_info_list = [_create_model_info(
-        'amazon.titan-text-premier-v1:0',
-        'Titan Text Premier is a powerful and advanced model in the Titan Text family, designed to deliver outstanding performance for a wide range of enterprise applications. With its cutting-edge capabilities it offers improved accuracy and excellent results, making it a great choice for organizations that need first-class text-processing solutions.',
-        ModelTypeConst.LLM,
-        BedrockLLMModelCredential,
-        BedrockModel
-    ),
+    model_info_list = [
+        _create_model_info(
+            'anthropic.claude-v2:1',
+            'An update to Claude 2 with double the context window and improved reliability, lower hallucination rates, and better evidence-based accuracy over long documents and in RAG contexts.',
+            ModelTypeConst.LLM,
+            BedrockLLMModelCredential,
+            BedrockModel
+        ),
+        _create_model_info(
+            'anthropic.claude-v2',
+            'A highly capable Anthropic model for a wide variety of tasks, from sophisticated dialogue and creative content generation to detailed instruction following.',
+            ModelTypeConst.LLM,
+            BedrockLLMModelCredential,
+            BedrockModel
+        ),
+        _create_model_info(
+            'anthropic.claude-3-haiku-20240307-v1:0',
+            'Claude 3 Haiku is the fastest and most compact model from Anthropic, with near-instant responsiveness. It answers simple queries and requests quickly, letting customers build seamless AI experiences that mimic human interaction. Claude 3 Haiku accepts images, returns text output, and provides a 200K context window.',
+            ModelTypeConst.LLM,
+            BedrockLLMModelCredential,
+            BedrockModel
+        ),
+        _create_model_info(
+            'anthropic.claude-3-sonnet-20240229-v1:0',
+            'Claude 3 Sonnet from Anthropic strikes the ideal balance between intelligence and speed, especially for enterprise workloads. It delivers maximum utility at a lower price than competing offerings, and it is engineered to be a dependable choice for deploying AI at scale.',
+            ModelTypeConst.LLM,
+            BedrockLLMModelCredential,
+            BedrockModel
+        ),
+        _create_model_info(
+            'anthropic.claude-3-5-sonnet-20240620-v1:0',
+            'Claude 3.5 Sonnet raises the industry bar for intelligence, outperforming competitor models and Claude 3 Opus across a wide range of evaluations, while delivering the speed and cost-effectiveness of the mid-tier Claude models.',
+            ModelTypeConst.LLM,
+            BedrockLLMModelCredential,
+            BedrockModel
+        ),
+        _create_model_info(
+            'anthropic.claude-instant-v1',
+            'A faster and more affordable yet still very capable model that handles a range of tasks, including casual dialogue, text analysis, summarization, and document question answering.',
+            ModelTypeConst.LLM,
+            BedrockLLMModelCredential,
+            BedrockModel
+        ),
+        _create_model_info(
+            'amazon.titan-text-premier-v1:0',
+            'Titan Text Premier is a powerful and advanced model in the Titan Text family, designed to deliver outstanding performance for a wide range of enterprise applications. With its cutting-edge capabilities it offers improved accuracy and excellent results, making it a great choice for organizations that need first-class text-processing solutions.',
+            ModelTypeConst.LLM,
+            BedrockLLMModelCredential,
+            BedrockModel
+        ),
         _create_model_info(
             'amazon.titan-text-lite-v1',
             'Amazon Titan Text Lite is a lightweight, efficient model, ideal for fine-tuning English-language tasks such as summarization and copywriting, where customers want a smaller, more cost-effective, and highly customizable model',
@@ -59,7 +102,7 @@ def _initialize_model_info():
         _create_model_info(
             'mistral.mistral-7b-instruct-v0:2',
             '7B dense transformer, quick to deploy and easy to customize. Small yet powerful for a variety of use cases. Supports English and code, with a 32k context window.',
-            ModelTypeConst.EMBEDDING,
+            ModelTypeConst.LLM,
             BedrockLLMModelCredential,
             BedrockModel),
         _create_model_info(
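
Every entry in model_info_list follows the same five-argument shape: model id, description, model type, credential class, model class. As a hedged illustration (this model id and description are placeholders, not part of the PR), one more Bedrock chat model would be registered as:

    _create_model_info(
        'meta.llama3-8b-instruct-v1:0',   # placeholder model id
        'Illustrative description only.',
        ModelTypeConst.LLM,
        BedrockLLMModelCredential,
        BedrockModel
    ),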
78 changes: 78 additions & 0 deletions apps/setting/models_provider/impl/base_chat_open_ai.py
@@ -0,0 +1,78 @@
# coding=utf-8

from typing import List, Dict, Optional, Any, Iterator, Type

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import BaseMessage, AIMessageChunk, BaseMessageChunk
from langchain_core.outputs import ChatGenerationChunk
from langchain_openai import ChatOpenAI
from langchain_openai.chat_models.base import _convert_delta_to_message_chunk


class BaseChatOpenAI(ChatOpenAI):
    """ChatOpenAI variant that records the usage payload of the last streamed
    call, so token counts come from the API rather than a local tokenizer."""

    def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
        return self.__dict__.get('_last_generation_info')

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        # Guarded lookup: the info dict does not exist before the first call.
        return (self.get_last_generation_info() or {}).get('prompt_tokens', 0)

    def get_num_tokens(self, text: str) -> int:
        return (self.get_last_generation_info() or {}).get('completion_tokens', 0)

    def _stream(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        # Ask the endpoint to append a final usage-only chunk to the stream.
        kwargs["stream"] = True
        kwargs["stream_options"] = {"include_usage": True}
        payload = self._get_request_payload(messages, stop=stop, **kwargs)
        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
        if self.include_response_headers:
            raw_response = self.client.with_raw_response.create(**payload)
            response = raw_response.parse()
            base_generation_info = {"headers": dict(raw_response.headers)}
        else:
            response = self.client.create(**payload)
            base_generation_info = {}
        with response:
            is_first_chunk = True
            for chunk in response:
                if not isinstance(chunk, dict):
                    chunk = chunk.model_dump()
                if len(chunk["choices"]) == 0:
                    # The usage-only chunk has an empty choices list:
                    # record the token counts and yield nothing for it.
                    if token_usage := chunk.get("usage"):
                        self.__dict__.setdefault('_last_generation_info', {}).update(token_usage)
                        logprobs = None
                    else:
                        continue
                else:
                    choice = chunk["choices"][0]
                    if choice["delta"] is None:
                        continue
                    message_chunk = _convert_delta_to_message_chunk(
                        choice["delta"], default_chunk_class
                    )
                    generation_info = {**base_generation_info} if is_first_chunk else {}
                    if finish_reason := choice.get("finish_reason"):
                        generation_info["finish_reason"] = finish_reason
                    if model_name := chunk.get("model"):
                        generation_info["model_name"] = model_name
                    if system_fingerprint := chunk.get("system_fingerprint"):
                        generation_info["system_fingerprint"] = system_fingerprint

                    logprobs = choice.get("logprobs")
                    if logprobs:
                        generation_info["logprobs"] = logprobs
                    default_chunk_class = message_chunk.__class__
                    generation_chunk = ChatGenerationChunk(
                        message=message_chunk, generation_info=generation_info or None
                    )
                    if run_manager:
                        run_manager.on_llm_new_token(
                            generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
                        )
                    is_first_chunk = False
                    yield generation_chunk
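
Because _stream requests stream_options={"include_usage": True}, an OpenAI-compatible endpoint appends one final chunk whose choices list is empty and whose usage field carries the prompt and completion token counts; that recorded payload is what backs the two token-count methods above. A minimal usage sketch, assuming an OpenAI-compatible endpoint (the model name, key, and URL below are placeholders):

    llm = BaseChatOpenAI(
        model='gpt-4o-mini',                           # placeholder model name
        openai_api_key='sk-...',                       # placeholder key
        openai_api_base='https://api.openai.com/v1',   # placeholder endpoint
    )
    for chunk in llm.stream('Say hi'):
        print(chunk.content, end='')
    # After the stream is consumed, the counts reflect the recorded usage payload:
    print(llm.get_num_tokens_from_messages([]))  # prompt_tokens of the last call
    print(llm.get_num_tokens(''))                # completion_tokens of the last call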
@@ -8,14 +8,11 @@
 """
 from typing import List, Dict
 
-from langchain_core.messages import BaseMessage, get_buffer_string
-from langchain_openai import ChatOpenAI
-
-from common.config.tokenizer_manage_config import TokenizerManage
 from setting.models_provider.base_model_provider import MaxKBBaseModel
+from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
 
 
-class DeepSeekChatModel(MaxKBBaseModel, ChatOpenAI):
+class DeepSeekChatModel(MaxKBBaseModel, BaseChatOpenAI):
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
         deepseek_chat_open_ai = DeepSeekChatModel(
@@ -25,10 +22,3 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
         )
         return deepseek_chat_open_ai
-
-    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
-        tokenizer = TokenizerManage.get_tokenizer()
-        return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
-
-    def get_num_tokens(self, text: str) -> int:
-        tokenizer = TokenizerManage.get_tokenizer()
-        return len(tokenizer.encode(text))
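
The same refactor repeats for the Kimi, Ollama, and OpenAI models below: the local tokenizer-based estimates are deleted, token counting is inherited from BaseChatOpenAI, and each provider subclass only supplies construction. A minimal sketch of the resulting pattern (the class name is hypothetical):

    class ExampleChatModel(MaxKBBaseModel, BaseChatOpenAI):
        @staticmethod
        def new_instance(model_type, model_name, model_credential, **model_kwargs):
            return ExampleChatModel(
                model=model_name,
                openai_api_base=model_credential.get('api_base'),
                openai_api_key=model_credential.get('api_key'),
            )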
@@ -8,14 +8,14 @@
 """
 from typing import List, Dict
 
-from langchain_community.chat_models import ChatOpenAI
-from langchain_core.messages import BaseMessage, get_buffer_string
-
-from common.config.tokenizer_manage_config import TokenizerManage
 from setting.models_provider.base_model_provider import MaxKBBaseModel
+from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
 
 
-class KimiChatModel(MaxKBBaseModel, ChatOpenAI):
+class KimiChatModel(MaxKBBaseModel, BaseChatOpenAI):
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
         kimi_chat_open_ai = KimiChatModel(
@@ -25,10 +25,3 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
         )
         return kimi_chat_open_ai
-
-    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
-        tokenizer = TokenizerManage.get_tokenizer()
-        return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
-
-    def get_num_tokens(self, text: str) -> int:
-        tokenizer = TokenizerManage.get_tokenizer()
-        return len(tokenizer.encode(text))
@@ -9,11 +9,11 @@
 from typing import List, Dict
 from urllib.parse import urlparse, ParseResult
 
-from langchain_community.chat_models import ChatOpenAI
-from langchain_core.messages import BaseMessage, get_buffer_string
-
-from common.config.tokenizer_manage_config import TokenizerManage
 from setting.models_provider.base_model_provider import MaxKBBaseModel
+from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
 
 
 def get_base_url(url: str):
@@ -24,19 +24,11 @@ def get_base_url(url: str):
     return result_url[:-1] if result_url.endswith("/") else result_url
 
 
-class OllamaChatModel(MaxKBBaseModel, ChatOpenAI):
+class OllamaChatModel(MaxKBBaseModel, BaseChatOpenAI):
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
         api_base = model_credential.get('api_base', '')
         base_url = get_base_url(api_base)
         base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1')
         return OllamaChatModel(model=model_name, openai_api_base=base_url,
                                openai_api_key=model_credential.get('api_key'))
-
-    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
-        tokenizer = TokenizerManage.get_tokenizer()
-        return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
-
-    def get_num_tokens(self, text: str) -> int:
-        tokenizer = TokenizerManage.get_tokenizer()
-        return len(tokenizer.encode(text))
@@ -8,27 +8,19 @@
 """
 from typing import List, Dict
 
-from langchain_core.messages import BaseMessage, get_buffer_string
-from langchain_openai import ChatOpenAI
-
-from common.config.tokenizer_manage_config import TokenizerManage
 from setting.models_provider.base_model_provider import MaxKBBaseModel
+from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
 
 
-class OpenAIChatModel(MaxKBBaseModel, ChatOpenAI):
+class OpenAIChatModel(MaxKBBaseModel, BaseChatOpenAI):
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
         azure_chat_open_ai = OpenAIChatModel(
             model=model_name,
             openai_api_base=model_credential.get('api_base'),
-            openai_api_key=model_credential.get('api_key')
+            openai_api_key=model_credential.get('api_key'),
+            streaming=model_kwargs.get('streaming', False),
+            max_tokens=model_kwargs.get('max_tokens', 5),
+            temperature=model_kwargs.get('temperature', 0.5),
         )
         return azure_chat_open_ai
-
-    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
-        tokenizer = TokenizerManage.get_tokenizer()
-        return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
-
-    def get_num_tokens(self, text: str) -> int:
-        tokenizer = TokenizerManage.get_tokenizer()
-        return len(tokenizer.encode(text))
40 changes: 24 additions & 16 deletions apps/setting/models_provider/impl/xf_model_provider/model/llm.py
@@ -6,16 +6,15 @@
 @date:2024/04/19 15:55
 @desc:
 """
-
 import json
 from typing import List, Optional, Any, Iterator, Dict
 
-from langchain_community.chat_models import ChatSparkLLM
-from langchain_community.chat_models.sparkllm import _convert_message_to_dict, _convert_delta_to_message_chunk
+from langchain_community.chat_models.sparkllm import _convert_message_to_dict, _convert_delta_to_message_chunk, \
+    ChatSparkLLM
 from langchain_core.callbacks import CallbackManagerForLLMRun
-from langchain_core.messages import BaseMessage, AIMessageChunk, get_buffer_string
+from langchain_core.messages import BaseMessage, AIMessageChunk
 from langchain_core.outputs import ChatGenerationChunk
 
-from common.config.tokenizer_manage_config import TokenizerManage
 from setting.models_provider.base_model_provider import MaxKBBaseModel
 
 
@@ -31,16 +30,19 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
             spark_api_key=model_credential.get('spark_api_key'),
             spark_api_secret=model_credential.get('spark_api_secret'),
             spark_api_url=model_credential.get('spark_api_url'),
-            spark_llm_domain=model_name
+            spark_llm_domain=model_name,
+            temperature=model_kwargs.get('temperature', 0.5),
+            max_tokens=model_kwargs.get('max_tokens', 5),
         )
 
+    def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
+        return self.__dict__.get('_last_generation_info')
+
     def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
-        tokenizer = TokenizerManage.get_tokenizer()
-        return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+        # Guarded lookup: the info dict does not exist before the first call.
+        return (self.get_last_generation_info() or {}).get('prompt_tokens', 0)
 
     def get_num_tokens(self, text: str) -> int:
-        tokenizer = TokenizerManage.get_tokenizer()
-        return len(tokenizer.encode(text))
+        return (self.get_last_generation_info() or {}).get('completion_tokens', 0)
 
     def _stream(
             self,
@@ -58,11 +60,17 @@ def _stream(
             True,
         )
         for content in self.client.subscribe(timeout=self.request_timeout):
-            if "data" not in content:
+            if "data" in content:
+                delta = content["data"]
+                chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
+                cg_chunk = ChatGenerationChunk(message=chunk)
+            elif "usage" in content:
+                # Spark reports token usage in a separate frame; record it.
+                generation_info = content["usage"]
+                self.__dict__.setdefault('_last_generation_info', {}).update(generation_info)
+                continue
+            else:
                 continue
-            delta = content["data"]
-            chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
-            cg_chunk = ChatGenerationChunk(message=chunk)
-            if run_manager:
-                run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)
-            yield cg_chunk
+            if cg_chunk is not None:
+                if run_manager:
+                    run_manager.on_llm_new_token(str(cg_chunk.message.content), chunk=cg_chunk)
+                yield cg_chunk
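
Spark delivers text deltas and token accounting in separate websocket frames, so the rewritten loop records the usage frame into _last_generation_info and only yields chunks built from data frames. Illustrative frame shapes, with example values rather than captured output:

    data_frame = {'data': {'role': 'assistant', 'content': '...'}}    # yielded as a chunk
    usage_frame = {'usage': {'prompt_tokens': 12, 'completion_tokens': 34, 'total_tokens': 46}}  # recorded only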
@@ -0,0 +1 @@
# coding=utf-8
@@ -0,0 +1,38 @@
# coding=utf-8
from typing import Dict

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding


class XinferenceEmbeddingModelCredential(BaseForm, BaseModelCredential):
    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        model_type_list = provider.get_model_type_list()
        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
            raise AppApiException(ValidCode.valid_error.value, f'The {model_type} model type is not supported')
        try:
            model_list = provider.get_base_model_list(model_credential.get('api_base'), 'embedding')
        except Exception as e:
            raise AppApiException(ValidCode.valid_error.value, "Invalid API domain")
        exist = provider.get_model_info_by_name(model_list, model_name)
        model: LocalEmbedding = provider.get_model(model_type, model_name, model_credential)
        if len(exist) == 0:
            # Kick off a background download, then report that the model is missing.
            model.start_down_model_thread()
            raise AppApiException(ValidCode.model_not_fount, "Model does not exist; please download it first")
        model.embed_query('你好')  # simple probe query to verify the embedding endpoint
        return True

    def encryption_dict(self, model_info: Dict[str, object]):
        return model_info

    def build_model(self, model_info: Dict[str, object]):
        for key in ['model']:
            if key not in model_info:
                raise AppApiException(500, f'The {key} field is required')
        return self

    api_base = forms.TextInputField('API Domain', required=True)
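
A hedged sketch of exercising this credential end to end; the endpoint, model name, type value, and provider import are illustrative assumptions, not part of this diff:

    from setting.models_provider.impl.xinference_model_provider.xinference_model_provider import XinferenceModelProvider

    credential = XinferenceEmbeddingModelCredential()
    credential.is_valid(
        'EMBEDDING', 'bge-m3',                                     # assumed type value and model name
        {'api_base': 'http://127.0.0.1:9997', 'model': 'bge-m3'},  # placeholder endpoint
        XinferenceModelProvider(),
        raise_exception=True,
    )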