add support for Gemini #1465

Merged · 1 commit · Jul 11, 2024
37 changes: 36 additions & 1 deletion api/db/init_data.py
@@ -175,6 +175,11 @@ def init_superuser():
"logo": "",
"tags": "LLM,TEXT EMBEDDING",
"status": "1",
},{
"name": "Gemini",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,IMAGE2TEXT",
"status": "1",
}
# {
# "name": "文心一言",
@@ -898,7 +903,37 @@ def init_llm_factory():
"tags": "TEXT EMBEDDING",
"max_tokens": 2048,
"model_type": LLMType.EMBEDDING.value
},
}, {
"fid": factory_infos[17]["name"],
"llm_name": "gemini-1.5-pro-latest",
"tags": "LLM,CHAT,1024K",
"max_tokens": 1024*1024,
"model_type": LLMType.CHAT.value
}, {
"fid": factory_infos[17]["name"],
"llm_name": "gemini-1.5-flash-latest",
"tags": "LLM,CHAT,1024K",
"max_tokens": 1024*1024,
"model_type": LLMType.CHAT.value
}, {
"fid": factory_infos[17]["name"],
"llm_name": "gemini-1.0-pro",
"tags": "LLM,CHAT,30K",
"max_tokens": 30*1024,
"model_type": LLMType.CHAT.value
}, {
"fid": factory_infos[17]["name"],
"llm_name": "gemini-1.0-pro-vision-latest",
"tags": "LLM,IMAGE2TEXT,12K",
"max_tokens": 12*1024,
"model_type": LLMType.IMAGE2TEXT.value
}, {
"fid": factory_infos[17]["name"],
"llm_name": "text-embedding-004",
"tags": "TEXT EMBEDDING",
"max_tokens": 2048,
"model_type": LLMType.EMBEDDING.value
}
]
for info in factory_infos:
try:
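For reviewers, one way to sanity-check the registered model names and context sizes against the live API (a minimal sketch, not part of this diff; it assumes `google-generativeai` is installed, and `YOUR_API_KEY` is a placeholder):

```python
import google.generativeai as genai

genai.configure(api_key="YOUR_API_KEY")  # placeholder, not a real key
for m in genai.list_models():
    # Each model entry reports its supported methods and input token limit,
    # which should roughly line up with the max_tokens values registered above.
    print(m.name, m.supported_generation_methods, m.input_token_limit)
```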
61 changes: 61 additions & 0 deletions rag/llm/chat_model.py
@@ -621,3 +621,64 @@ def chat_streamly(self, system, history, gen_conf):
            yield ans + f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}"

        yield num_tokens_from_string(ans)

class GeminiChat(Base):

    def __init__(self, key, model_name, base_url=None):
        from google.generativeai import client, GenerativeModel

        client.configure(api_key=key)
        _client = client.get_default_generative_client()
        self.model_name = 'models/' + model_name
        self.model = GenerativeModel(model_name=self.model_name)
        self.model._client = _client

    def chat(self, system, history, gen_conf):
        # Gemini has no dedicated system role; prepend the system prompt as a user turn.
        if system:
            history.insert(0, {"role": "user", "parts": system})
        if 'max_tokens' in gen_conf:
            gen_conf['max_output_tokens'] = gen_conf['max_tokens']
        # Keep only the generation-config keys Gemini understands.
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_output_tokens"]:
                del gen_conf[k]
        # Map OpenAI-style messages to Gemini's format: "assistant" -> "model", "content" -> "parts".
        for item in history:
            if 'role' in item and item['role'] == 'assistant':
                item['role'] = 'model'
            if 'content' in item:
                item['parts'] = item.pop('content')

        try:
            response = self.model.generate_content(
                history,
                generation_config=gen_conf)
            ans = response.text
            return ans, response.usage_metadata.total_token_count
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "user", "parts": system})
        if 'max_tokens' in gen_conf:
            gen_conf['max_output_tokens'] = gen_conf['max_tokens']
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_output_tokens"]:
                del gen_conf[k]
        for item in history:
            if 'role' in item and item['role'] == 'assistant':
                item['role'] = 'model'
            if 'content' in item:
                item['parts'] = item.pop('content')
        ans = ""
        try:
            response = self.model.generate_content(
                history,
                generation_config=gen_conf, stream=True)
            # Accumulate partial text and yield the running answer after each chunk.
            for resp in response:
                ans += resp.text
                yield ans

        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        # Total token usage is reported on the last streamed chunk.
        yield response._chunks[-1].usage_metadata.total_token_count
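
For reference, a minimal sketch of exercising the new chat wrapper directly (illustrative only; the key is a placeholder and the call bypasses RAGFlow's usual model-factory wiring):

```python
from rag.llm.chat_model import GeminiChat

chat = GeminiChat(key="YOUR_API_KEY", model_name="gemini-1.5-flash-latest")
answer, used_tokens = chat.chat(
    system="You are a concise assistant.",
    history=[{"role": "user", "content": "Explain retrieval-augmented generation in one sentence."}],
    gen_conf={"temperature": 0.2, "max_tokens": 256},
)
print(used_tokens, answer)  # token count as reported by Gemini's usage metadata
```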
23 changes: 23 additions & 0 deletions rag/llm/cv_model.py
@@ -203,6 +203,29 @@ def describe(self, image, max_tokens=300):
        )
        return res.choices[0].message.content.strip(), res.usage.total_tokens

class GeminiCV(Base):
    def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
        from google.generativeai import client, GenerativeModel
        client.configure(api_key=key)
        _client = client.get_default_generative_client()
        self.model_name = model_name
        self.model = GenerativeModel(model_name=self.model_name)
        self.model._client = _client
        self.lang = lang

    def describe(self, image, max_tokens=2048):
        from PIL import Image
        gen_config = {'max_output_tokens': max_tokens}
        prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
            "Please describe the content of this picture, like where, when, who, what happened. If it contains numerical data, please extract it."
        b64 = self.image2base64(image)
        img = Image.open(BytesIO(base64.b64decode(b64)))
        # Gemini vision models accept a mixed list of text and PIL images.
        res = self.model.generate_content(
            [prompt, img],
            generation_config=gen_config,
        )
        return res.text, res.usage_metadata.total_token_count

class LocalCV(Base):
    def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
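A similar sketch for the vision wrapper; it assumes the inherited `image2base64` helper accepts raw image bytes, as the other CV wrappers in this module do, and the key and file name are placeholders:

```python
from rag.llm.cv_model import GeminiCV

cv = GeminiCV(key="YOUR_API_KEY", lang="English")
with open("sample.png", "rb") as f:
    description, used_tokens = cv.describe(f.read(), max_tokens=1024)
print(used_tokens, description)
```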
26 changes: 25 additions & 1 deletion rag/llm/embedding_model.py
@@ -31,7 +31,7 @@
import asyncio
from api.utils.file_utils import get_home_cache_dir
from rag.utils import num_tokens_from_string, truncate

import google.generativeai as genai

class Base(ABC):
    def __init__(self, key, model_name):
@@ -419,3 +419,27 @@ def encode_queries(self, text):

        return np.array(embeddings), token_count

class GeminiEmbed(Base):
    def __init__(self, key, model_name='text-embedding-004', **kwargs):
        genai.configure(api_key=key)
        # Accept either "text-embedding-004" or the fully qualified "models/..." form.
        self.model_name = model_name if model_name.startswith('models/') else 'models/' + model_name

    def encode(self, texts: list, batch_size=32):
        # Gemini embedding inputs are capped around 2048 tokens, so truncate first.
        texts = [truncate(t, 2048) for t in texts]
        token_count = sum(num_tokens_from_string(text) for text in texts)
        result = genai.embed_content(
            model=self.model_name,
            content=texts,
            task_type="retrieval_document",
            title="Embedding of list of strings")
        return np.array(result['embedding']), token_count

    def encode_queries(self, text):
        result = genai.embed_content(
            model=self.model_name,
            content=truncate(text, 2048),
            task_type="retrieval_document",
            title="Embedding of single string")
        token_count = num_tokens_from_string(text)
        return np.array(result['embedding']), token_count
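
And a sketch for the embedding wrapper (placeholder key; the shape comment assumes text-embedding-004's default 768-dimensional output):

```python
from rag.llm.embedding_model import GeminiEmbed

embd = GeminiEmbed(key="YOUR_API_KEY", model_name="text-embedding-004")
vectors, used_tokens = embd.encode(["first passage", "second passage"])
print(vectors.shape, used_tokens)  # e.g. (2, 768) for text-embedding-004

query_vec, query_tokens = embd.encode_queries("What is retrieval-augmented generation?")
```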
1 change: 1 addition & 0 deletions requirements.txt
@@ -147,3 +147,4 @@ markdown==3.6
mistralai==0.4.2
boto3==1.34.140
duckduckgo_search==6.1.9
google-generativeai==0.7.2
1 change: 1 addition & 0 deletions requirements_arm.txt
@@ -148,3 +148,4 @@ markdown==3.6
mistralai==0.4.2
boto3==1.34.140
duckduckgo_search==6.1.9
google-generativeai==0.7.2
1 change: 1 addition & 0 deletions requirements_dev.txt
@@ -133,3 +133,4 @@ markdown==3.6
mistralai==0.4.2
boto3==1.34.140
duckduckgo_search==6.1.9
google-generativeai==0.7.2
114 changes: 114 additions & 0 deletions web/src/assets/svg/llm/gemini.svg
(New SVG icon asset; contents not rendered in this view.)
1 change: 1 addition & 0 deletions web/src/pages/user-setting/setting-model/index.tsx
@@ -61,6 +61,7 @@ const IconMap = {
  Mistral: 'mistral',
  'Azure-OpenAI': 'azure',
  Bedrock: 'bedrock',
  Gemini: 'gemini',
};

const LlmIcon = ({ name }: { name: string }) => {