add support for Replicate (#1980)
### What problem does this PR solve?

Adds support for Replicate chat and embedding models (#1853).

### Type of change


- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
hangters and aopstudio authored Aug 19, 2024
1 parent be5a678 commit 79426fc
Showing 10 changed files with 94 additions and 12 deletions.
4 changes: 2 additions & 2 deletions api/apps/llm_app.py
@@ -149,7 +149,7 @@ def add_llm():
    msg = ""
    if llm["model_type"] == LLMType.EMBEDDING.value:
        mdl = EmbeddingModel[factory](
            key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible"] else None,
            key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible","Replicate"] else None,
            model_name=llm["llm_name"],
            base_url=llm["api_base"])
        try:
@@ -160,7 +160,7 @@ def add_llm():
            msg += f"\nFail to access embedding model({llm['llm_name']})." + str(e)
    elif llm["model_type"] == LLMType.CHAT.value:
        mdl = ChatModel[factory](
            key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible"] else None,
            key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible","Replicate"] else None,
            model_name=llm["llm_name"],
            base_url=llm["api_base"]
        )
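For context (not part of the diff): `EmbeddingModel` and `ChatModel` are name-to-class registries defined in `rag/llm/__init__.py`. A minimal sketch of the dispatch pattern the hunk above relies on, with illustrative names rather than actual RAGFlow symbols:

```python
# Simplified sketch of the registry dispatch used by add_llm(); the names
# KEY_REQUIRED and build_chat_model are illustrative, not RAGFlow symbols.
KEY_REQUIRED = {"VolcEngine", "Bedrock", "OpenAI-API-Compatible", "Replicate"}

def build_chat_model(factory: str, llm: dict, registry: dict):
    # Only providers in KEY_REQUIRED receive the stored API key; the rest
    # authenticate via their base_url or environment configuration.
    key = llm["api_key"] if factory in KEY_REQUIRED else None
    return registry[factory](key=key,
                             model_name=llm["llm_name"],
                             base_url=llm["api_base"])
```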
7 changes: 7 additions & 0 deletions conf/llm_factories.json
@@ -3113,6 +3113,13 @@
                    "model_type": "image2text"
                }
            ]
        },
        {
            "name": "Replicate",
            "logo": "",
            "tags": "LLM,TEXT EMBEDDING",
            "status": "1",
            "llm": []
        }
    ]
}
6 changes: 4 additions & 2 deletions rag/llm/__init__.py
@@ -42,7 +42,8 @@
    "TogetherAI": TogetherAIEmbed,
    "PerfXCloud": PerfXCloudEmbed,
    "Upstage": UpstageEmbed,
    "SILICONFLOW": SILICONFLOWEmbed
    "SILICONFLOW": SILICONFLOWEmbed,
    "Replicate": ReplicateEmbed
}


@@ -96,7 +97,8 @@
    "Upstage":UpstageChat,
    "novita.ai": NovitaAIChat,
    "SILICONFLOW": SILICONFLOWChat,
    "01.AI": YiChat
    "01.AI": YiChat,
    "Replicate": ReplicateChat
}

55 changes: 53 additions & 2 deletions rag/llm/chat_model.py
@@ -1003,7 +1003,7 @@ def __init__(self, key, model_name, base_url="https://api.together.xyz/v1"):
            base_url = "https://api.together.xyz/v1"
        super().__init__(key, model_name, base_url)


class PerfXCloudChat(Base):
    def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"):
        if not base_url:
@@ -1036,4 +1036,55 @@ class YiChat(Base):
    def __init__(self, key, model_name, base_url="https://api.01.ai/v1"):
        if not base_url:
            base_url = "https://api.01.ai/v1"
        super().__init__(key, model_name, base_url)


class ReplicateChat(Base):
    def __init__(self, key, model_name, base_url=None):
        from replicate.client import Client

        self.model_name = model_name
        self.client = Client(api_token=key)
        self.system = ""

    def chat(self, system, history, gen_conf):
        # Replicate models expect `max_new_tokens` rather than the
        # OpenAI-style `max_tokens`.
        if "max_tokens" in gen_conf:
            gen_conf["max_new_tokens"] = gen_conf.pop("max_tokens")
        if system:
            self.system = system
        # Replicate takes a single plain-text prompt, so flatten the last
        # five chat turns into one string.
        prompt = "\n".join(
            [item["role"] + ":" + item["content"] for item in history[-5:]]
        )
        ans = ""
        try:
            response = self.client.run(
                self.model_name,
                input={"system_prompt": self.system, "prompt": prompt, **gen_conf},
            )
            # Language models on Replicate return an iterator of string chunks.
            ans = "".join(response)
            return ans, num_tokens_from_string(ans)
        except Exception as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if "max_tokens" in gen_conf:
            gen_conf["max_new_tokens"] = gen_conf.pop("max_tokens")
        if system:
            self.system = system
        prompt = "\n".join(
            [item["role"] + ":" + item["content"] for item in history[-5:]]
        )
        ans = ""
        try:
            response = self.client.run(
                self.model_name,
                input={"system_prompt": self.system, "prompt": prompt, **gen_conf},
            )
            # Yield the accumulated answer after each chunk, then the token count.
            for resp in response:
                ans += resp
                yield ans

        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield num_tokens_from_string(ans)
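A minimal usage sketch for the new wrapper; the API token and model slug below are placeholders, not values shipped in this PR:

```python
# Hypothetical usage of ReplicateChat; "meta/meta-llama-3-8b-instruct" is an
# example Replicate model slug and "r8_xxx" is a placeholder token.
from rag.llm.chat_model import ReplicateChat

mdl = ReplicateChat(key="r8_xxx", model_name="meta/meta-llama-3-8b-instruct")
history = [{"role": "user", "content": "Give me a one-line summary of RAGFlow."}]

# Blocking call: returns the full answer and an estimated token count.
answer, tokens = mdl.chat("You are a concise assistant.", history, {"max_tokens": 128})

# Streaming call: yields progressively longer partial answers, then the count.
for chunk in mdl.chat_streamly("You are a concise assistant.", history, {}):
    print(chunk)
```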
22 changes: 20 additions & 2 deletions rag/llm/embedding_model.py
@@ -561,7 +561,7 @@ def __init__(self, key, model_name, base_url="https://api.together.xyz/v1"):
            base_url = "https://api.together.xyz/v1"
        super().__init__(key, model_name, base_url)


class PerfXCloudEmbed(OpenAIEmbed):
    def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"):
        if not base_url:
@@ -580,4 +580,22 @@ class SILICONFLOWEmbed(OpenAIEmbed):
    def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1"):
        if not base_url:
            base_url = "https://api.siliconflow.cn/v1"
        super().__init__(key, model_name, base_url)


class ReplicateEmbed(Base):
    def __init__(self, key, model_name, base_url=None):
        from replicate.client import Client

        self.model_name = model_name
        self.client = Client(api_token=key)

    def encode(self, texts: list, batch_size=32):
        from json import dumps

        # The target model is expected to accept a JSON-encoded list of
        # strings under the "texts" input; batch_size is currently unused.
        res = self.client.run(self.model_name, input={"texts": dumps(texts)})
        return np.array(res), sum([num_tokens_from_string(text) for text in texts])

    def encode_queries(self, text):
        from json import dumps

        # The Replicate client exposes no `embed` method, so reuse `run`
        # with a single-item batch in the same input format as `encode`.
        res = self.client.run(self.model_name, input={"texts": dumps([text])})
        return np.array(res), num_tokens_from_string(text)
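A hedged usage sketch for the embedding wrapper; the model slug is illustrative (the `llm_factories.json` entry ships an empty `llm` list), and the target model must accept the JSON-encoded `texts` input the wrapper sends:

```python
# Hypothetical usage of ReplicateEmbed; the slug and token are placeholders.
from rag.llm.embedding_model import ReplicateEmbed

emb = ReplicateEmbed(key="r8_xxx", model_name="replicate/all-mpnet-base-v2")

# Batch encoding: one embedding row per input text, plus a token estimate.
vectors, n_tokens = emb.encode(["hello world", "retrieval-augmented generation"])
print(vectors.shape, n_tokens)

# Single-query encoding follows the same path with a one-item batch.
qvec, q_tokens = emb.encode_queries("hello world")
```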
3 changes: 2 additions & 1 deletion requirements.txt
@@ -65,6 +65,7 @@ python_pptx==0.6.23
readability_lxml==0.8.1
redis==5.0.3
Requests==2.32.2
replicate==0.31.0
roman_numbers==1.0.2
ruamel.base==1.0.0
scholarly==1.7.11
@@ -87,4 +88,4 @@ wikipedia==1.4.0
word2number==1.1
xgboost==2.1.0
xpinyin==0.7.6
zhipuai==2.0.1
3 changes: 2 additions & 1 deletion requirements_arm.txt
@@ -102,6 +102,7 @@ python-pptx==0.6.23
PyYAML==6.0.1
redis==5.0.3
regex==2023.12.25
replicate==0.31.0
requests==2.32.2
ruamel.yaml==0.18.6
ruamel.yaml.clib==0.2.8
@@ -161,4 +162,4 @@ markdown_to_json==2.1.1
scholarly==1.7.11
deepl==1.18.0
psycopg2-binary==2.9.9
tabulate==0.9.0
1 change: 1 addition & 0 deletions web/src/assets/svg/llm/replicate.svg
2 changes: 1 addition & 1 deletion web/src/pages/user-setting/constants.tsx
@@ -17,4 +17,4 @@ export const UserSettingIconMap = {

export * from '@/constants/setting';

export const LocalLlmFactories = ['Ollama', 'Xinference','LocalAI','LM-Studio',"OpenAI-API-Compatible",'TogetherAI'];
export const LocalLlmFactories = ['Ollama', 'Xinference','LocalAI','LM-Studio',"OpenAI-API-Compatible",'TogetherAI','Replicate'];
3 changes: 2 additions & 1 deletion web/src/pages/user-setting/setting-model/constant.ts
@@ -30,7 +30,8 @@ export const IconMap = {
  Upstage: 'upstage',
  'novita.ai': 'novita-ai',
  SILICONFLOW: 'siliconflow',
  "01.AI": 'yi'
  "01.AI": 'yi',
  "Replicate": 'replicate'
};

export const BedrockRegionList = [