Solve knowledge graph issue when calling Gemini model #2738

Merged
merged 1 commit on Oct 8, 2024
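As far as the diff below shows, the core functional change is in GeminiChat.chat: history entries that arrive with a "system" role are remapped to "user" before the request is built (assistant entries were already remapped to "model"). Gemini's generate_content() accepts only "user" and "model" roles inside the history, and the knowledge graph pipeline apparently passes a system-role entry, which made Gemini calls fail. The rest of the diff is whitespace and formatting cleanup. A standalone sketch of that normalization follows; _normalize_gemini_history is a hypothetical helper name, while the PR itself applies the same mapping in place:

# Sketch only: map OpenAI-style chat roles onto the two roles that Gemini's
# generate_content() accepts in its history ("user" and "model").
def _normalize_gemini_history(history):
    normalized = []
    for item in history:
        role = item.get("role", "user")
        if role == "assistant":
            role = "model"   # Gemini names the assistant side "model"
        elif role == "system":
            role = "user"    # Gemini rejects "system" entries inside the history
        normalized.append({"role": role, "parts": item.get("content", "")})
    return normalized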
126 changes: 64 additions & 62 deletions rag/llm/chat_model.py
@@ -23,7 +23,7 @@
from rag.nlp import is_english
from rag.utils import num_tokens_from_string
from groq import Groq
import os
import os
import json
import requests
import asyncio
@@ -62,17 +62,17 @@ def chat_streamly(self, system, history, gen_conf):
stream=True,
**gen_conf)
for resp in response:
if not resp.choices:continue
if not resp.choices: continue
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
resp.choices[0].delta.content = ""
ans += resp.choices[0].delta.content
total_tokens = (
(
total_tokens
+ num_tokens_from_string(resp.choices[0].delta.content)
total_tokens
+ num_tokens_from_string(resp.choices[0].delta.content)
)
if not hasattr(resp, "usage") or not resp.usage
else resp.usage.get("total_tokens",total_tokens)
else resp.usage.get("total_tokens", total_tokens)
)
if resp.choices[0].finish_reason == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
@@ -87,13 +87,13 @@ def chat_streamly(self, system, history, gen_conf):

class GptTurbo(Base):
def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
if not base_url: base_url="https://api.openai.com/v1"
if not base_url: base_url = "https://api.openai.com/v1"
super().__init__(key, model_name, base_url)


class MoonshotChat(Base):
def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
if not base_url: base_url="https://api.moonshot.cn/v1"
if not base_url: base_url = "https://api.moonshot.cn/v1"
super().__init__(key, model_name, base_url)


@@ -108,7 +108,7 @@ def __init__(self, key=None, model_name="", base_url=""):

class DeepSeekChat(Base):
def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"):
if not base_url: base_url="https://api.deepseek.com/v1"
if not base_url: base_url = "https://api.deepseek.com/v1"
super().__init__(key, model_name, base_url)


@@ -178,14 +178,14 @@ def chat_streamly(self, system, history, gen_conf):
stream=True,
**self._format_params(gen_conf))
for resp in response:
if not resp.choices:continue
if not resp.choices: continue
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
resp.choices[0].delta.content = ""
ans += resp.choices[0].delta.content
total_tokens = (
(
total_tokens
+ num_tokens_from_string(resp.choices[0].delta.content)
total_tokens
+ num_tokens_from_string(resp.choices[0].delta.content)
)
if not hasattr(resp, "usage")
else resp.usage["total_tokens"]
@@ -252,7 +252,8 @@ def chat_streamly(self, system, history, gen_conf):
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
yield ans
else:
yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find("Access")<0 else "Out of credit. Please set the API key in **settings > Model providers.**"
yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find(
"Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)

@@ -298,7 +299,7 @@ def chat_streamly(self, system, history, gen_conf):
**gen_conf
)
for resp in response:
if not resp.choices[0].delta.content:continue
if not resp.choices[0].delta.content: continue
delta = resp.choices[0].delta.content
ans += delta
if resp.choices[0].finish_reason == "length":
@@ -411,15 +412,15 @@ def __init__(self, key, model_name):
self.client = Client(port=12345, protocol="grpc", asyncio=True)

def _prepare_prompt(self, system, history, gen_conf):
from rag.svr.jina_server import Prompt,Generation
from rag.svr.jina_server import Prompt, Generation
if system:
history.insert(0, {"role": "system", "content": system})
if "max_tokens" in gen_conf:
gen_conf["max_new_tokens"] = gen_conf.pop("max_tokens")
return Prompt(message=history, gen_conf=gen_conf)

def _stream_response(self, endpoint, prompt):
from rag.svr.jina_server import Prompt,Generation
from rag.svr.jina_server import Prompt, Generation
answer = ""
try:
res = self.client.stream_doc(
@@ -463,10 +464,10 @@ def __init__(self, key, model_name, base_url='https://ark.cn-beijing.volces.com/

class MiniMaxChat(Base):
def __init__(
self,
key,
model_name,
base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
self,
key,
model_name,
base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
):
if not base_url:
base_url = "https://api.minimax.chat/v1/text/chatcompletion_v2"
@@ -583,7 +584,7 @@ def chat_streamly(self, system, history, gen_conf):
messages=history,
**gen_conf)
for resp in response:
if not resp.choices or not resp.choices[0].delta.content:continue
if not resp.choices or not resp.choices[0].delta.content: continue
ans += resp.choices[0].delta.content
total_tokens += 1
if resp.choices[0].finish_reason == "length":
@@ -620,19 +621,18 @@ def chat(self, system, history, gen_conf):
gen_conf["topP"] = gen_conf["top_p"]
_ = gen_conf.pop("top_p")
for item in history:
if not isinstance(item["content"],list) and not isinstance(item["content"],tuple):
item["content"] = [{"text":item["content"]}]

if not isinstance(item["content"], list) and not isinstance(item["content"], tuple):
item["content"] = [{"text": item["content"]}]

try:
# Send the message to the model, using a basic inference configuration.
response = self.client.converse(
modelId=self.model_name,
messages=history,
inferenceConfig=gen_conf,
system=[{"text": (system if system else "Answer the user's message.")}] ,
system=[{"text": (system if system else "Answer the user's message.")}],
)

# Extract and print the response text.
ans = response["output"]["message"]["content"][0]["text"]
return ans, num_tokens_from_string(ans)
@@ -652,9 +652,9 @@ def chat_streamly(self, system, history, gen_conf):
gen_conf["topP"] = gen_conf["top_p"]
_ = gen_conf.pop("top_p")
for item in history:
if not isinstance(item["content"],list) and not isinstance(item["content"],tuple):
item["content"] = [{"text":item["content"]}]
if not isinstance(item["content"], list) and not isinstance(item["content"], tuple):
item["content"] = [{"text": item["content"]}]

if self.model_name.split('.')[0] == 'ai21':
try:
response = self.client.converse(
@@ -684,7 +684,7 @@ def chat_streamly(self, system, history, gen_conf):
if "contentBlockDelta" in resp:
ans += resp["contentBlockDelta"]["delta"]["text"]
yield ans

except (ClientError, Exception) as e:
yield ans + f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}"

@@ -693,22 +693,21 @@ def chat_streamly(self, system, history, gen_conf):

class GeminiChat(Base):

def __init__(self, key, model_name,base_url=None):
from google.generativeai import client,GenerativeModel
def __init__(self, key, model_name, base_url=None):
from google.generativeai import client, GenerativeModel

client.configure(api_key=key)
_client = client.get_default_generative_client()
self.model_name = 'models/' + model_name
self.model = GenerativeModel(model_name=self.model_name)
self.model._client = _client


def chat(self,system,history,gen_conf):

def chat(self, system, history, gen_conf):
from google.generativeai.types import content_types

if system:
self.model._system_instruction = content_types.to_content(system)

if 'max_tokens' in gen_conf:
gen_conf['max_output_tokens'] = gen_conf['max_tokens']
for k in list(gen_conf.keys()):
@@ -717,9 +716,11 @@ def chat(self,system,history,gen_conf):
for item in history:
if 'role' in item and item['role'] == 'assistant':
item['role'] = 'model'
if 'content' in item :
if 'role' in item and item['role'] == 'system':
item['role'] = 'user'
if 'content' in item:
item['parts'] = item.pop('content')

try:
response = self.model.generate_content(
history,
@@ -731,7 +732,7 @@ def chat(self,system,history,gen_conf):

def chat_streamly(self, system, history, gen_conf):
from google.generativeai.types import content_types

if system:
self.model._system_instruction = content_types.to_content(system)
if 'max_tokens' in gen_conf:
@@ -742,25 +743,25 @@ def chat_streamly(self, system, history, gen_conf):
for item in history:
if 'role' in item and item['role'] == 'assistant':
item['role'] = 'model'
if 'content' in item :
if 'content' in item:
item['parts'] = item.pop('content')
ans = ""
try:
response = self.model.generate_content(
history,
generation_config=gen_conf,stream=True)
generation_config=gen_conf, stream=True)
for resp in response:
ans += resp.text
yield ans

except Exception as e:
yield ans + "\n**ERROR**: " + str(e)

yield response._chunks[-1].usage_metadata.total_token_count
yield response._chunks[-1].usage_metadata.total_token_count


class GroqChat:
def __init__(self, key, model_name,base_url=''):
def __init__(self, key, model_name, base_url=''):
self.client = Groq(api_key=key)
self.model_name = model_name

@@ -942,7 +943,7 @@ def chat_streamly(self, system, history, gen_conf):
class LeptonAIChat(Base):
def __init__(self, key, model_name, base_url=None):
if not base_url:
base_url = os.path.join("https://"+model_name+".lepton.run","api","v1")
base_url = os.path.join("https://" + model_name + ".lepton.run", "api", "v1")
super().__init__(key, model_name, base_url)


@@ -1058,7 +1059,7 @@ def chat(self, system, history, gen_conf):
)

_gen_conf = {}
_history = [{k.capitalize(): v for k, v in item.items() } for item in history]
_history = [{k.capitalize(): v for k, v in item.items()} for item in history]
if system:
_history.insert(0, {"Role": "system", "Content": system})
if "temperature" in gen_conf:
@@ -1084,7 +1085,7 @@ def chat_streamly(self, system, history, gen_conf):
)

_gen_conf = {}
_history = [{k.capitalize(): v for k, v in item.items() } for item in history]
_history = [{k.capitalize(): v for k, v in item.items()} for item in history]
if system:
_history.insert(0, {"Role": "system", "Content": system})

@@ -1121,7 +1122,7 @@ def chat_streamly(self, system, history, gen_conf):

class SparkChat(Base):
def __init__(
self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1"
self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1"
):
if not base_url:
base_url = "https://spark-api-open.xf-yun.com/v1"
@@ -1141,26 +1142,27 @@ def __init__(self, key, model_name, base_url=None):
import qianfan

key = json.loads(key)
ak = key.get("yiyan_ak","")
sk = key.get("yiyan_sk","")
self.client = qianfan.ChatCompletion(ak=ak,sk=sk)
ak = key.get("yiyan_ak", "")
sk = key.get("yiyan_sk", "")
self.client = qianfan.ChatCompletion(ak=ak, sk=sk)
self.model_name = model_name.lower()
self.system = ""

def chat(self, system, history, gen_conf):
if system:
self.system = system
gen_conf["penalty_score"] = (
(gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2
) + 1
(gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty",
0)) / 2
) + 1
if "max_tokens" in gen_conf:
gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
ans = ""

try:
response = self.client.do(
model=self.model_name,
messages=history,
model=self.model_name,
messages=history,
system=self.system,
**gen_conf
).body
@@ -1174,17 +1176,18 @@ def chat_streamly(self, system, history, gen_conf):
if system:
self.system = system
gen_conf["penalty_score"] = (
(gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2
) + 1
(gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty",
0)) / 2
) + 1
if "max_tokens" in gen_conf:
gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
ans = ""
total_tokens = 0

try:
response = self.client.do(
model=self.model_name,
messages=history,
model=self.model_name,
messages=history,
system=self.system,
stream=True,
**gen_conf
@@ -1415,4 +1418,3 @@ def chat_streamly(self, system, history, gen_conf):
yield ans + "\n**ERROR**: " + str(e)

yield response._chunks[-1].usage_metadata.total_token_count
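
As a side note on how the chat_streamly generators touched here are shaped: every yield except the last is the accumulated answer text, and the final yield is the total token count taken from the response metadata. A hedged consumer sketch, not code from this PR (the model name and prompts are placeholders, only the constructor and method signatures come from the diff):

# Hypothetical consumer of a chat_streamly() generator from rag/llm/chat_model.py.
from rag.llm.chat_model import GeminiChat

model = GeminiChat(key="YOUR_API_KEY", model_name="gemini-1.5-flash")  # model name is an assumption
history = [{"role": "user", "content": "Extract entities and relations from this passage."}]

answer, token_count = "", 0
for chunk in model.chat_streamly("You are a helpful assistant.", history, {"temperature": 0.2}):
    if isinstance(chunk, str):
        answer = chunk        # each string yield is the full answer so far
    else:
        token_count = chunk   # the final yield is the total token count
print(answer, token_count)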