Skip to content

Commit

Permalink
Add Support for AWS Bedrock (infiniflow#1408)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

infiniflow#308 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: KevinHuSh <kevinhu.sh@gmail.com>
  • Loading branch information
guoyuhao2330 and KevinHuSh authored Jul 8, 2024
1 parent cc92c10 commit e0b7040
Show file tree
Hide file tree
Showing 8 changed files with 325 additions and 7 deletions.
18 changes: 14 additions & 4 deletions api/apps/llm_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,15 +109,23 @@ def set_api_key():
def add_llm():
req = request.json
factory = req["llm_factory"]
# For VolcEngine, due to its special authentication method
# Assemble volc_ak, volc_sk, endpoint_id into api_key

if factory == "VolcEngine":
# For VolcEngine, due to its special authentication method
# Assemble volc_ak, volc_sk, endpoint_id into api_key
temp = list(eval(req["llm_name"]).items())[0]
llm_name = temp[0]
endpoint_id = temp[1]
api_key = '{' + f'"volc_ak": "{req.get("volc_ak", "")}", ' \
f'"volc_sk": "{req.get("volc_sk", "")}", ' \
f'"ep_id": "{endpoint_id}", ' + '}'
elif factory == "Bedrock":
# For Bedrock, due to its special authentication method
# Assemble bedrock_ak, bedrock_sk, bedrock_region
llm_name = req["llm_name"]
api_key = '{' + f'"bedrock_ak": "{req.get("bedrock_ak", "")}", ' \
f'"bedrock_sk": "{req.get("bedrock_sk", "")}", ' \
f'"bedrock_region": "{req.get("bedrock_region", "")}", ' + '}'
else:
llm_name = req["llm_name"]
api_key = "xxxxxxxxxxxxxxx"
Expand All @@ -134,7 +142,9 @@ def add_llm():
msg = ""
if llm["model_type"] == LLMType.EMBEDDING.value:
mdl = EmbeddingModel[factory](
key=None, model_name=llm["llm_name"], base_url=llm["api_base"])
key=llm['api_key'] if factory in ["VolcEngine", "Bedrock"] else None,
model_name=llm["llm_name"],
base_url=llm["api_base"])
try:
arr, tc = mdl.encode(["Test if the api key is available"])
if len(arr[0]) == 0 or tc == 0:
Expand All @@ -143,7 +153,7 @@ def add_llm():
msg += f"\nFail to access embedding model({llm['llm_name']})." + str(e)
elif llm["model_type"] == LLMType.CHAT.value:
mdl = ChatModel[factory](
key=llm['api_key'] if factory == "VolcEngine" else None,
key=llm['api_key'] if factory in ["VolcEngine", "Bedrock"] else None,
model_name=llm["llm_name"],
base_url=llm["api_base"]
)
Expand Down
170 changes: 169 additions & 1 deletion api/db/init_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,11 @@ def init_superuser():
"logo": "",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
},{
"name": "Bedrock",
"logo": "",
"tags": "LLM,TEXT EMBEDDING",
"status": "1",
}
# {
# "name": "文心一言",
Expand Down Expand Up @@ -730,7 +735,170 @@ def init_llm_factory():
"max_tokens": 765,
"model_type": LLMType.IMAGE2TEXT.value
},

# ------------------------ Bedrock -----------------------
{
"fid": factory_infos[16]["name"],
"llm_name": "ai21.j2-ultra-v1",
"tags": "LLM,CHAT,8k",
"max_tokens": 8191,
"model_type": LLMType.CHAT.value
}, {
"fid": factory_infos[16]["name"],
"llm_name": "ai21.j2-mid-v1",
"tags": "LLM,CHAT,8k",
"max_tokens": 8191,
"model_type": LLMType.CHAT.value
}, {
"fid": factory_infos[16]["name"],
"llm_name": "cohere.command-text-v14",
"tags": "LLM,CHAT,4k",
"max_tokens": 4096,
"model_type": LLMType.CHAT.value
}, {
"fid": factory_infos[16]["name"],
"llm_name": "cohere.command-light-text-v14",
"tags": "LLM,CHAT,4k",
"max_tokens": 4096,
"model_type": LLMType.CHAT.value
}, {
"fid": factory_infos[16]["name"],
"llm_name": "cohere.command-r-v1:0",
"tags": "LLM,CHAT,128k",
"max_tokens": 128 * 1024,
"model_type": LLMType.CHAT.value
}, {
"fid": factory_infos[16]["name"],
"llm_name": "cohere.command-r-plus-v1:0",
"tags": "LLM,CHAT,128k",
"max_tokens": 128000,
"model_type": LLMType.CHAT.value
}, {
"fid": factory_infos[16]["name"],
"llm_name": "anthropic.claude-v2",
"tags": "LLM,CHAT,100k",
"max_tokens": 100 * 1024,
"model_type": LLMType.CHAT.value
}, {
"fid": factory_infos[16]["name"],
"llm_name": "anthropic.claude-v2:1",
"tags": "LLM,CHAT,200k",
"max_tokens": 200 * 1024,
"model_type": LLMType.CHAT.value
}, {
"fid": factory_infos[16]["name"],
"llm_name": "anthropic.claude-3-sonnet-20240229-v1:0",
"tags": "LLM,CHAT,200k",
"max_tokens": 200 * 1024,
"model_type": LLMType.CHAT.value
}, {
"fid": factory_infos[16]["name"],
"llm_name": "anthropic.claude-3-5-sonnet-20240620-v1:0",
"tags": "LLM,CHAT,200k",
"max_tokens": 200 * 1024,
"model_type": LLMType.CHAT.value
}, {
"fid": factory_infos[16]["name"],
"llm_name": "anthropic.claude-3-haiku-20240307-v1:0",
"tags": "LLM,CHAT,200k",
"max_tokens": 200 * 1024,
"model_type": LLMType.CHAT.value
}, {
"fid": factory_infos[16]["name"],
"llm_name": "anthropic.claude-3-opus-20240229-v1:0",
"tags": "LLM,CHAT,200k",
"max_tokens": 200 * 1024,
"model_type": LLMType.CHAT.value
}, {
"fid": factory_infos[16]["name"],
"llm_name": "anthropic.claude-instant-v1",
"tags": "LLM,CHAT,100k",
"max_tokens": 100 * 1024,
"model_type": LLMType.CHAT.value
}, {
"fid": factory_infos[16]["name"],
"llm_name": "amazon.titan-text-express-v1",
"tags": "LLM,CHAT,8k",
"max_tokens": 8192,
"model_type": LLMType.CHAT.value
}, {
"fid": factory_infos[16]["name"],
"llm_name": "amazon.titan-text-premier-v1:0",
"tags": "LLM,CHAT,32k",
"max_tokens": 32 * 1024,
"model_type": LLMType.CHAT.value
}, {
"fid": factory_infos[16]["name"],
"llm_name": "amazon.titan-text-lite-v1",
"tags": "LLM,CHAT,4k",
"max_tokens": 4096,
"model_type": LLMType.CHAT.value
}, {
"fid": factory_infos[16]["name"],
"llm_name": "meta.llama2-13b-chat-v1",
"tags": "LLM,CHAT,4k",
"max_tokens": 4096,
"model_type": LLMType.CHAT.value
}, {
"fid": factory_infos[16]["name"],
"llm_name": "meta.llama2-70b-chat-v1",
"tags": "LLM,CHAT,4k",
"max_tokens": 4096,
"model_type": LLMType.CHAT.value
}, {
"fid": factory_infos[16]["name"],
"llm_name": "meta.llama3-8b-instruct-v1:0",
"tags": "LLM,CHAT,8k",
"max_tokens": 8192,
"model_type": LLMType.CHAT.value
}, {
"fid": factory_infos[16]["name"],
"llm_name": "meta.llama3-70b-instruct-v1:0",
"tags": "LLM,CHAT,8k",
"max_tokens": 8192,
"model_type": LLMType.CHAT.value
}, {
"fid": factory_infos[16]["name"],
"llm_name": "mistral.mistral-7b-instruct-v0:2",
"tags": "LLM,CHAT,8k",
"max_tokens": 8192,
"model_type": LLMType.CHAT.value
}, {
"fid": factory_infos[16]["name"],
"llm_name": "mistral.mixtral-8x7b-instruct-v0:1",
"tags": "LLM,CHAT,4k",
"max_tokens": 4096,
"model_type": LLMType.CHAT.value
}, {
"fid": factory_infos[16]["name"],
"llm_name": "mistral.mistral-large-2402-v1:0",
"tags": "LLM,CHAT,8k",
"max_tokens": 8192,
"model_type": LLMType.CHAT.value
}, {
"fid": factory_infos[16]["name"],
"llm_name": "mistral.mistral-small-2402-v1:0",
"tags": "LLM,CHAT,8k",
"max_tokens": 8192,
"model_type": LLMType.CHAT.value
}, {
"fid": factory_infos[16]["name"],
"llm_name": "amazon.titan-embed-text-v2:0",
"tags": "TEXT EMBEDDING",
"max_tokens": 8192,
"model_type": LLMType.EMBEDDING.value
}, {
"fid": factory_infos[16]["name"],
"llm_name": "cohere.embed-english-v3",
"tags": "TEXT EMBEDDING",
"max_tokens": 2048,
"model_type": LLMType.EMBEDDING.value
}, {
"fid": factory_infos[16]["name"],
"llm_name": "cohere.embed-multilingual-v3",
"tags": "TEXT EMBEDDING",
"max_tokens": 2048,
"model_type": LLMType.EMBEDDING.value
},
]
for info in factory_infos:
try:
Expand Down
6 changes: 4 additions & 2 deletions rag/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
"BaiChuan": BaiChuanEmbed,
"Jina": JinaEmbed,
"BAAI": DefaultEmbedding,
"Mistral": MistralEmbed
"Mistral": MistralEmbed,
"Bedrock": BedrockEmbed
}


Expand All @@ -58,7 +59,8 @@
"VolcEngine": VolcEngineChat,
"BaiChuan": BaiChuanChat,
"MiniMax": MiniMaxChat,
"Mistral": MistralChat
"Mistral": MistralChat,
"Bedrock": BedrockChat
}


Expand Down
87 changes: 87 additions & 0 deletions rag/llm/chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,3 +533,90 @@ def chat_streamly(self, system, history, gen_conf):
yield ans + "\n**ERROR**: " + str(e)

yield total_tokens


class BedrockChat(Base):

def __init__(self, key, model_name, **kwargs):
import boto3
from botocore.exceptions import ClientError
self.bedrock_ak = eval(key).get('bedrock_ak', '')
self.bedrock_sk = eval(key).get('bedrock_sk', '')
self.bedrock_region = eval(key).get('bedrock_region', '')
self.model_name = model_name
self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)

def chat(self, system, history, gen_conf):
if system:
history.insert(0, {"role": "system", "content": system})
for k in list(gen_conf.keys()):
if k not in ["temperature", "top_p", "max_tokens"]:
del gen_conf[k]
if "max_tokens" in gen_conf:
gen_conf["maxTokens"] = gen_conf["max_tokens"]
_ = gen_conf.pop("max_tokens")
if "top_p" in gen_conf:
gen_conf["topP"] = gen_conf["top_p"]
_ = gen_conf.pop("top_p")

try:
# Send the message to the model, using a basic inference configuration.
response = self.client.converse(
modelId=self.model_name,
messages=history,
inferenceConfig=gen_conf
)

# Extract and print the response text.
ans = response["output"]["message"]["content"][0]["text"]
return ans, num_tokens_from_string(ans)

except (ClientError, Exception) as e:
return f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}", 0

def chat_streamly(self, system, history, gen_conf):
if system:
history.insert(0, {"role": "system", "content": system})
for k in list(gen_conf.keys()):
if k not in ["temperature", "top_p", "max_tokens"]:
del gen_conf[k]
if "max_tokens" in gen_conf:
gen_conf["maxTokens"] = gen_conf["max_tokens"]
_ = gen_conf.pop("max_tokens")
if "top_p" in gen_conf:
gen_conf["topP"] = gen_conf["top_p"]
_ = gen_conf.pop("top_p")

if self.model_name.split('.')[0] == 'ai21':
try:
response = self.client.converse(
modelId=self.model_name,
messages=history,
inferenceConfig=gen_conf
)
ans = response["output"]["message"]["content"][0]["text"]
return ans, num_tokens_from_string(ans)

except (ClientError, Exception) as e:
return f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}", 0

ans = ""
try:
# Send the message to the model, using a basic inference configuration.
streaming_response = self.client.converse_stream(
modelId=self.model_name,
messages=history,
inferenceConfig=gen_conf
)

# Extract and print the streamed response text in real-time.
for resp in streaming_response["stream"]:
if "contentBlockDelta" in resp:
ans += resp["contentBlockDelta"]["delta"]["text"]
yield ans

except (ClientError, Exception) as e:
yield ans + f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}"

yield num_tokens_from_string(ans)
45 changes: 45 additions & 0 deletions rag/llm/embedding_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,3 +374,48 @@ def encode_queries(self, text):
res = self.client.embeddings(input=[truncate(text, 8196)],
model=self.model_name)
return np.array(res.data[0].embedding), res.usage.total_tokens


class BedrockEmbed(Base):
def __init__(self, key, model_name,
**kwargs):
import boto3
self.bedrock_ak = eval(key).get('bedrock_ak', '')
self.bedrock_sk = eval(key).get('bedrock_sk', '')
self.bedrock_region = eval(key).get('bedrock_region', '')
self.model_name = model_name
self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)

def encode(self, texts: list, batch_size=32):
texts = [truncate(t, 8196) for t in texts]
embeddings = []
token_count = 0
for text in texts:
if self.model_name.split('.')[0] == 'amazon':
body = {"inputText": text}
elif self.model_name.split('.')[0] == 'cohere':
body = {"texts": [text], "input_type": 'search_document'}

response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
model_response = json.loads(response["body"].read())
embeddings.extend([model_response["embedding"]])
token_count += num_tokens_from_string(text)

return np.array(embeddings), token_count

def encode_queries(self, text):

embeddings = []
token_count = num_tokens_from_string(text)
if self.model_name.split('.')[0] == 'amazon':
body = {"inputText": truncate(text, 8196)}
elif self.model_name.split('.')[0] == 'cohere':
body = {"texts": [truncate(text, 8196)], "input_type": 'search_query'}

response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
model_response = json.loads(response["body"].read())
embeddings.extend([model_response["embedding"]])

return np.array(embeddings), token_count

2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -144,4 +144,6 @@ cn2an==0.5.22
roman-numbers==1.0.2
word2number==1.1
markdown==3.6
mistralai==0.4.2
boto3==1.34.140
duckduckgo_search==6.1.9
Loading

0 comments on commit e0b7040

Please sign in to comment.