Skip to content

Commit

Permalink
add support for Gemini (infiniflow#1465)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

infiniflow#1036

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
  • Loading branch information
hangters and aopstudio authored Jul 11, 2024
1 parent 8fcb86c commit ef80633
Show file tree
Hide file tree
Showing 9 changed files with 263 additions and 2 deletions.
37 changes: 36 additions & 1 deletion api/db/init_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,11 @@ def init_superuser():
"logo": "",
"tags": "LLM,TEXT EMBEDDING",
"status": "1",
},{
"name": "Gemini",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,IMAGE2TEXT",
"status": "1",
}
# {
# "name": "文心一言",
Expand Down Expand Up @@ -898,7 +903,37 @@ def init_llm_factory():
"tags": "TEXT EMBEDDING",
"max_tokens": 2048,
"model_type": LLMType.EMBEDDING.value
},
}, {
"fid": factory_infos[17]["name"],
"llm_name": "gemini-1.5-pro-latest",
"tags": "LLM,CHAT,1024K",
"max_tokens": 1024*1024,
"model_type": LLMType.CHAT.value
}, {
"fid": factory_infos[17]["name"],
"llm_name": "gemini-1.5-flash-latest",
"tags": "LLM,CHAT,1024K",
"max_tokens": 1024*1024,
"model_type": LLMType.CHAT.value
}, {
"fid": factory_infos[17]["name"],
"llm_name": "gemini-1.0-pro",
"tags": "LLM,CHAT,30K",
"max_tokens": 30*1024,
"model_type": LLMType.CHAT.value
}, {
"fid": factory_infos[17]["name"],
"llm_name": "gemini-1.0-pro-vision-latest",
"tags": "LLM,IMAGE2TEXT,12K",
"max_tokens": 12*1024,
"model_type": LLMType.IMAGE2TEXT.value
}, {
"fid": factory_infos[17]["name"],
"llm_name": "text-embedding-004",
"tags": "TEXT EMBEDDING",
"max_tokens": 2048,
"model_type": LLMType.EMBEDDING.value
}
]
for info in factory_infos:
try:
Expand Down
61 changes: 61 additions & 0 deletions rag/llm/chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,3 +621,64 @@ def chat_streamly(self, system, history, gen_conf):
yield ans + f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}"

yield num_tokens_from_string(ans)

class GeminiChat(Base):

def __init__(self, key, model_name,base_url=None):
from google.generativeai import client,GenerativeModel

client.configure(api_key=key)
_client = client.get_default_generative_client()
self.model_name = 'models/' + model_name
self.model = GenerativeModel(model_name=self.model_name)
self.model._client = _client

def chat(self,system,history,gen_conf):
if system:
history.insert(0, {"role": "user", "parts": system})
if 'max_tokens' in gen_conf:
gen_conf['max_output_tokens'] = gen_conf['max_tokens']
for k in list(gen_conf.keys()):
if k not in ["temperature", "top_p", "max_output_tokens"]:
del gen_conf[k]
for item in history:
if 'role' in item and item['role'] == 'assistant':
item['role'] = 'model'
if 'content' in item :
item['parts'] = item.pop('content')

try:
response = self.model.generate_content(
history,
generation_config=gen_conf)
ans = response.text
return ans, response.usage_metadata.total_token_count
except Exception as e:
return "**ERROR**: " + str(e), 0

def chat_streamly(self, system, history, gen_conf):
if system:
history.insert(0, {"role": "user", "parts": system})
if 'max_tokens' in gen_conf:
gen_conf['max_output_tokens'] = gen_conf['max_tokens']
for k in list(gen_conf.keys()):
if k not in ["temperature", "top_p", "max_output_tokens"]:
del gen_conf[k]
for item in history:
if 'role' in item and item['role'] == 'assistant':
item['role'] = 'model'
if 'content' in item :
item['parts'] = item.pop('content')
ans = ""
try:
response = self.model.generate_content(
history,
generation_config=gen_conf,stream=True)
for resp in response:
ans += resp.text
yield ans

except Exception as e:
yield ans + "\n**ERROR**: " + str(e)

yield response._chunks[-1].usage_metadata.total_token_count
23 changes: 23 additions & 0 deletions rag/llm/cv_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,29 @@ def describe(self, image, max_tokens=300):
)
return res.choices[0].message.content.strip(), res.usage.total_tokens

class GeminiCV(Base):
def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
from google.generativeai import client,GenerativeModel
client.configure(api_key=key)
_client = client.get_default_generative_client()
self.model_name = model_name
self.model = GenerativeModel(model_name=self.model_name)
self.model._client = _client
self.lang = lang

def describe(self, image, max_tokens=2048):
from PIL.Image import open
gen_config = {'max_output_tokens':max_tokens}
prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
"Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
b64 = self.image2base64(image)
img = open(BytesIO(base64.b64decode(b64)))
input = [prompt,img]
res = self.model.generate_content(
input,
generation_config=gen_config,
)
return res.text,res.usage_metadata.total_token_count

class LocalCV(Base):
def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
Expand Down
26 changes: 25 additions & 1 deletion rag/llm/embedding_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import asyncio
from api.utils.file_utils import get_home_cache_dir
from rag.utils import num_tokens_from_string, truncate

import google.generativeai as genai

class Base(ABC):
def __init__(self, key, model_name):
Expand Down Expand Up @@ -419,3 +419,27 @@ def encode_queries(self, text):

return np.array(embeddings), token_count

class GeminiEmbed(Base):
def __init__(self, key, model_name='models/text-embedding-004',
**kwargs):
genai.configure(api_key=key)
self.model_name = 'models/' + model_name

def encode(self, texts: list, batch_size=32):
texts = [truncate(t, 2048) for t in texts]
token_count = sum(num_tokens_from_string(text) for text in texts)
result = genai.embed_content(
model=self.model_name,
content=texts,
task_type="retrieval_document",
title="Embedding of list of strings")
return np.array(result['embedding']),token_count

def encode_queries(self, text):
result = genai.embed_content(
model=self.model_name,
content=truncate(text,2048),
task_type="retrieval_document",
title="Embedding of single string")
token_count = num_tokens_from_string(text)
return np.array(result['embedding']),token_count
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,4 @@ markdown==3.6
mistralai==0.4.2
boto3==1.34.140
duckduckgo_search==6.1.9
google-generativeai==0.7.2
1 change: 1 addition & 0 deletions requirements_arm.txt
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,4 @@ markdown==3.6
mistralai==0.4.2
boto3==1.34.140
duckduckgo_search==6.1.9
google-generativeai==0.7.2
1 change: 1 addition & 0 deletions requirements_dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,4 @@ markdown==3.6
mistralai==0.4.2
boto3==1.34.140
duckduckgo_search==6.1.9
google-generativeai==0.7.2
114 changes: 114 additions & 0 deletions web/src/assets/svg/llm/gemini.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions web/src/pages/user-setting/setting-model/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ const IconMap = {
Mistral: 'mistral',
'Azure-OpenAI': 'azure',
Bedrock: 'bedrock',
Gemini:'gemini',
};

const LlmIcon = ({ name }: { name: string }) => {
Expand Down

0 comments on commit ef80633

Please sign in to comment.