From e0b7040f3c508430fe114fdfc9a44e264a7b24ae Mon Sep 17 00:00:00 2001 From: H <43509927+guoyuhao2330@users.noreply.github.com> Date: Mon, 8 Jul 2024 09:37:34 +0800 Subject: [PATCH] Add Support for AWS Bedrock (#1408) ### What problem does this PR solve? #308 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: KevinHuSh --- api/apps/llm_app.py | 18 +++- api/db/init_data.py | 170 ++++++++++++++++++++++++++++++++++++- rag/llm/__init__.py | 6 +- rag/llm/chat_model.py | 87 +++++++++++++++++++ rag/llm/embedding_model.py | 45 ++++++++++ requirements.txt | 2 + requirements_arm.txt | 2 + requirements_dev.txt | 2 + 8 files changed, 325 insertions(+), 7 deletions(-) diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index c4a245ffb35..678f74ab722 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -109,15 +109,23 @@ def set_api_key(): def add_llm(): req = request.json factory = req["llm_factory"] - # For VolcEngine, due to its special authentication method - # Assemble volc_ak, volc_sk, endpoint_id into api_key + if factory == "VolcEngine": + # For VolcEngine, due to its special authentication method + # Assemble volc_ak, volc_sk, endpoint_id into api_key temp = list(eval(req["llm_name"]).items())[0] llm_name = temp[0] endpoint_id = temp[1] api_key = '{' + f'"volc_ak": "{req.get("volc_ak", "")}", ' \ f'"volc_sk": "{req.get("volc_sk", "")}", ' \ f'"ep_id": "{endpoint_id}", ' + '}' + elif factory == "Bedrock": + # For Bedrock, due to its special authentication method + # Assemble bedrock_ak, bedrock_sk, bedrock_region + llm_name = req["llm_name"] + api_key = '{' + f'"bedrock_ak": "{req.get("bedrock_ak", "")}", ' \ + f'"bedrock_sk": "{req.get("bedrock_sk", "")}", ' \ + f'"bedrock_region": "{req.get("bedrock_region", "")}", ' + '}' else: llm_name = req["llm_name"] api_key = "xxxxxxxxxxxxxxx" @@ -134,7 +142,9 @@ def add_llm(): msg = "" if llm["model_type"] == LLMType.EMBEDDING.value: mdl = EmbeddingModel[factory]( - key=None, model_name=llm["llm_name"], base_url=llm["api_base"]) + key=llm['api_key'] if factory in ["VolcEngine", "Bedrock"] else None, + model_name=llm["llm_name"], + base_url=llm["api_base"]) try: arr, tc = mdl.encode(["Test if the api key is available"]) if len(arr[0]) == 0 or tc == 0: @@ -143,7 +153,7 @@ def add_llm(): msg += f"\nFail to access embedding model({llm['llm_name']})." + str(e) elif llm["model_type"] == LLMType.CHAT.value: mdl = ChatModel[factory]( - key=llm['api_key'] if factory == "VolcEngine" else None, + key=llm['api_key'] if factory in ["VolcEngine", "Bedrock"] else None, model_name=llm["llm_name"], base_url=llm["api_base"] ) diff --git a/api/db/init_data.py b/api/db/init_data.py index 90d99f762d3..5405fc3e732 100644 --- a/api/db/init_data.py +++ b/api/db/init_data.py @@ -170,6 +170,11 @@ def init_superuser(): "logo": "", "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", "status": "1", +},{ + "name": "Bedrock", + "logo": "", + "tags": "LLM,TEXT EMBEDDING", + "status": "1", } # { # "name": "文心一言", @@ -730,7 +735,170 @@ def init_llm_factory(): "max_tokens": 765, "model_type": LLMType.IMAGE2TEXT.value }, - + # ------------------------ Bedrock ----------------------- + { + "fid": factory_infos[16]["name"], + "llm_name": "ai21.j2-ultra-v1", + "tags": "LLM,CHAT,8k", + "max_tokens": 8191, + "model_type": LLMType.CHAT.value + }, { + "fid": factory_infos[16]["name"], + "llm_name": "ai21.j2-mid-v1", + "tags": "LLM,CHAT,8k", + "max_tokens": 8191, + "model_type": LLMType.CHAT.value + }, { + "fid": factory_infos[16]["name"], + "llm_name": "cohere.command-text-v14", + "tags": "LLM,CHAT,4k", + "max_tokens": 4096, + "model_type": LLMType.CHAT.value + }, { + "fid": factory_infos[16]["name"], + "llm_name": "cohere.command-light-text-v14", + "tags": "LLM,CHAT,4k", + "max_tokens": 4096, + "model_type": LLMType.CHAT.value + }, { + "fid": factory_infos[16]["name"], + "llm_name": "cohere.command-r-v1:0", + "tags": "LLM,CHAT,128k", + "max_tokens": 128 * 1024, + "model_type": LLMType.CHAT.value + }, { + "fid": factory_infos[16]["name"], + "llm_name": "cohere.command-r-plus-v1:0", + "tags": "LLM,CHAT,128k", + "max_tokens": 128000, + "model_type": LLMType.CHAT.value + }, { + "fid": factory_infos[16]["name"], + "llm_name": "anthropic.claude-v2", + "tags": "LLM,CHAT,100k", + "max_tokens": 100 * 1024, + "model_type": LLMType.CHAT.value + }, { + "fid": factory_infos[16]["name"], + "llm_name": "anthropic.claude-v2:1", + "tags": "LLM,CHAT,200k", + "max_tokens": 200 * 1024, + "model_type": LLMType.CHAT.value + }, { + "fid": factory_infos[16]["name"], + "llm_name": "anthropic.claude-3-sonnet-20240229-v1:0", + "tags": "LLM,CHAT,200k", + "max_tokens": 200 * 1024, + "model_type": LLMType.CHAT.value + }, { + "fid": factory_infos[16]["name"], + "llm_name": "anthropic.claude-3-5-sonnet-20240620-v1:0", + "tags": "LLM,CHAT,200k", + "max_tokens": 200 * 1024, + "model_type": LLMType.CHAT.value + }, { + "fid": factory_infos[16]["name"], + "llm_name": "anthropic.claude-3-haiku-20240307-v1:0", + "tags": "LLM,CHAT,200k", + "max_tokens": 200 * 1024, + "model_type": LLMType.CHAT.value + }, { + "fid": factory_infos[16]["name"], + "llm_name": "anthropic.claude-3-opus-20240229-v1:0", + "tags": "LLM,CHAT,200k", + "max_tokens": 200 * 1024, + "model_type": LLMType.CHAT.value + }, { + "fid": factory_infos[16]["name"], + "llm_name": "anthropic.claude-instant-v1", + "tags": "LLM,CHAT,100k", + "max_tokens": 100 * 1024, + "model_type": LLMType.CHAT.value + }, { + "fid": factory_infos[16]["name"], + "llm_name": "amazon.titan-text-express-v1", + "tags": "LLM,CHAT,8k", + "max_tokens": 8192, + "model_type": LLMType.CHAT.value + }, { + "fid": factory_infos[16]["name"], + "llm_name": "amazon.titan-text-premier-v1:0", + "tags": "LLM,CHAT,32k", + "max_tokens": 32 * 1024, + "model_type": LLMType.CHAT.value + }, { + "fid": factory_infos[16]["name"], + "llm_name": "amazon.titan-text-lite-v1", + "tags": "LLM,CHAT,4k", + "max_tokens": 4096, + "model_type": LLMType.CHAT.value + }, { + "fid": factory_infos[16]["name"], + "llm_name": "meta.llama2-13b-chat-v1", + "tags": "LLM,CHAT,4k", + "max_tokens": 4096, + "model_type": LLMType.CHAT.value + }, { + "fid": factory_infos[16]["name"], + "llm_name": "meta.llama2-70b-chat-v1", + "tags": "LLM,CHAT,4k", + "max_tokens": 4096, + "model_type": LLMType.CHAT.value + }, { + "fid": factory_infos[16]["name"], + "llm_name": "meta.llama3-8b-instruct-v1:0", + "tags": "LLM,CHAT,8k", + "max_tokens": 8192, + "model_type": LLMType.CHAT.value + }, { + "fid": factory_infos[16]["name"], + "llm_name": "meta.llama3-70b-instruct-v1:0", + "tags": "LLM,CHAT,8k", + "max_tokens": 8192, + "model_type": LLMType.CHAT.value + }, { + "fid": factory_infos[16]["name"], + "llm_name": "mistral.mistral-7b-instruct-v0:2", + "tags": "LLM,CHAT,8k", + "max_tokens": 8192, + "model_type": LLMType.CHAT.value + }, { + "fid": factory_infos[16]["name"], + "llm_name": "mistral.mixtral-8x7b-instruct-v0:1", + "tags": "LLM,CHAT,4k", + "max_tokens": 4096, + "model_type": LLMType.CHAT.value + }, { + "fid": factory_infos[16]["name"], + "llm_name": "mistral.mistral-large-2402-v1:0", + "tags": "LLM,CHAT,8k", + "max_tokens": 8192, + "model_type": LLMType.CHAT.value + }, { + "fid": factory_infos[16]["name"], + "llm_name": "mistral.mistral-small-2402-v1:0", + "tags": "LLM,CHAT,8k", + "max_tokens": 8192, + "model_type": LLMType.CHAT.value + }, { + "fid": factory_infos[16]["name"], + "llm_name": "amazon.titan-embed-text-v2:0", + "tags": "TEXT EMBEDDING", + "max_tokens": 8192, + "model_type": LLMType.EMBEDDING.value + }, { + "fid": factory_infos[16]["name"], + "llm_name": "cohere.embed-english-v3", + "tags": "TEXT EMBEDDING", + "max_tokens": 2048, + "model_type": LLMType.EMBEDDING.value + }, { + "fid": factory_infos[16]["name"], + "llm_name": "cohere.embed-multilingual-v3", + "tags": "TEXT EMBEDDING", + "max_tokens": 2048, + "model_type": LLMType.EMBEDDING.value + }, ] for info in factory_infos: try: diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py index 9127d0c979d..513c1c44e73 100644 --- a/rag/llm/__init__.py +++ b/rag/llm/__init__.py @@ -31,7 +31,8 @@ "BaiChuan": BaiChuanEmbed, "Jina": JinaEmbed, "BAAI": DefaultEmbedding, - "Mistral": MistralEmbed + "Mistral": MistralEmbed, + "Bedrock": BedrockEmbed } @@ -58,7 +59,8 @@ "VolcEngine": VolcEngineChat, "BaiChuan": BaiChuanChat, "MiniMax": MiniMaxChat, - "Mistral": MistralChat + "Mistral": MistralChat, + "Bedrock": BedrockChat } diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 75a08509f11..7f70e45bc51 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -533,3 +533,90 @@ def chat_streamly(self, system, history, gen_conf): yield ans + "\n**ERROR**: " + str(e) yield total_tokens + + +class BedrockChat(Base): + + def __init__(self, key, model_name, **kwargs): + import boto3 + from botocore.exceptions import ClientError + self.bedrock_ak = eval(key).get('bedrock_ak', '') + self.bedrock_sk = eval(key).get('bedrock_sk', '') + self.bedrock_region = eval(key).get('bedrock_region', '') + self.model_name = model_name + self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region, + aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk) + + def chat(self, system, history, gen_conf): + if system: + history.insert(0, {"role": "system", "content": system}) + for k in list(gen_conf.keys()): + if k not in ["temperature", "top_p", "max_tokens"]: + del gen_conf[k] + if "max_tokens" in gen_conf: + gen_conf["maxTokens"] = gen_conf["max_tokens"] + _ = gen_conf.pop("max_tokens") + if "top_p" in gen_conf: + gen_conf["topP"] = gen_conf["top_p"] + _ = gen_conf.pop("top_p") + + try: + # Send the message to the model, using a basic inference configuration. + response = self.client.converse( + modelId=self.model_name, + messages=history, + inferenceConfig=gen_conf + ) + + # Extract and print the response text. + ans = response["output"]["message"]["content"][0]["text"] + return ans, num_tokens_from_string(ans) + + except (ClientError, Exception) as e: + return f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}", 0 + + def chat_streamly(self, system, history, gen_conf): + if system: + history.insert(0, {"role": "system", "content": system}) + for k in list(gen_conf.keys()): + if k not in ["temperature", "top_p", "max_tokens"]: + del gen_conf[k] + if "max_tokens" in gen_conf: + gen_conf["maxTokens"] = gen_conf["max_tokens"] + _ = gen_conf.pop("max_tokens") + if "top_p" in gen_conf: + gen_conf["topP"] = gen_conf["top_p"] + _ = gen_conf.pop("top_p") + + if self.model_name.split('.')[0] == 'ai21': + try: + response = self.client.converse( + modelId=self.model_name, + messages=history, + inferenceConfig=gen_conf + ) + ans = response["output"]["message"]["content"][0]["text"] + return ans, num_tokens_from_string(ans) + + except (ClientError, Exception) as e: + return f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}", 0 + + ans = "" + try: + # Send the message to the model, using a basic inference configuration. + streaming_response = self.client.converse_stream( + modelId=self.model_name, + messages=history, + inferenceConfig=gen_conf + ) + + # Extract and print the streamed response text in real-time. + for resp in streaming_response["stream"]: + if "contentBlockDelta" in resp: + ans += resp["contentBlockDelta"]["delta"]["text"] + yield ans + + except (ClientError, Exception) as e: + yield ans + f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}" + + yield num_tokens_from_string(ans) diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index eeba7f7b9ed..48081e0124d 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -374,3 +374,48 @@ def encode_queries(self, text): res = self.client.embeddings(input=[truncate(text, 8196)], model=self.model_name) return np.array(res.data[0].embedding), res.usage.total_tokens + + +class BedrockEmbed(Base): + def __init__(self, key, model_name, + **kwargs): + import boto3 + self.bedrock_ak = eval(key).get('bedrock_ak', '') + self.bedrock_sk = eval(key).get('bedrock_sk', '') + self.bedrock_region = eval(key).get('bedrock_region', '') + self.model_name = model_name + self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region, + aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk) + + def encode(self, texts: list, batch_size=32): + texts = [truncate(t, 8196) for t in texts] + embeddings = [] + token_count = 0 + for text in texts: + if self.model_name.split('.')[0] == 'amazon': + body = {"inputText": text} + elif self.model_name.split('.')[0] == 'cohere': + body = {"texts": [text], "input_type": 'search_document'} + + response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body)) + model_response = json.loads(response["body"].read()) + embeddings.extend([model_response["embedding"]]) + token_count += num_tokens_from_string(text) + + return np.array(embeddings), token_count + + def encode_queries(self, text): + + embeddings = [] + token_count = num_tokens_from_string(text) + if self.model_name.split('.')[0] == 'amazon': + body = {"inputText": truncate(text, 8196)} + elif self.model_name.split('.')[0] == 'cohere': + body = {"texts": [truncate(text, 8196)], "input_type": 'search_query'} + + response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body)) + model_response = json.loads(response["body"].read()) + embeddings.extend([model_response["embedding"]]) + + return np.array(embeddings), token_count + diff --git a/requirements.txt b/requirements.txt index 934cb64ed7a..048761dabf2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -144,4 +144,6 @@ cn2an==0.5.22 roman-numbers==1.0.2 word2number==1.1 markdown==3.6 +mistralai==0.4.2 +boto3==1.34.140 duckduckgo_search==6.1.9 diff --git a/requirements_arm.txt b/requirements_arm.txt index 51072682bd4..448b093619b 100644 --- a/requirements_arm.txt +++ b/requirements_arm.txt @@ -145,4 +145,6 @@ cn2an==0.5.22 roman-numbers==1.0.2 word2number==1.1 markdown==3.6 +mistralai==0.4.2 +boto3==1.34.140 duckduckgo_search==6.1.9 diff --git a/requirements_dev.txt b/requirements_dev.txt index 003d7ae9c02..8bc49f07ef0 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -130,4 +130,6 @@ cn2an==0.5.22 roman-numbers==1.0.2 word2number==1.1 markdown==3.6 +mistralai==0.4.2 +boto3==1.34.140 duckduckgo_search==6.1.9