Skip to content

Commit

Permalink
Add Gemini Vision, convert incoming audio (#1158)
Browse files Browse the repository at this point in the history
* Convert incoming audio to wav

* Add vision to Google provider, clean up others
  • Loading branch information
Josh-XT authored Apr 3, 2024
1 parent 73707f4 commit bcc4df7
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 57 deletions.
23 changes: 22 additions & 1 deletion agixt/endpoints/Completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import base64
import uuid
import json
import requests
from fastapi import APIRouter, Depends, Header
from Interactions import Interactions, get_tokens, log_interaction
from ApiClient import Agent, verify_api_key, get_api_client
Expand All @@ -19,6 +20,7 @@
TextToSpeech,
ImageCreation,
)
from pydub import AudioSegment

app = APIRouter()

Expand Down Expand Up @@ -177,8 +179,27 @@ async def chat_completion(
if "url" in message["audio_url"]
else message["audio_url"]
)
# If it is not a url, we need to find the file type and convert with pydub
if not audio_url.startswith("http"):
file_type = audio_url.split(",")[0].split("/")[1].split(";")[0]
audio_data = base64.b64decode(audio_url.split(",")[1])
audio_path = f"./WORKSPACE/{uuid.uuid4().hex}.{file_type}"
with open(audio_path, "wb") as f:
f.write(audio_data)
audio_url = audio_path
else:
# Download the audio file from the url, get the file type and convert to wav
audio_type = audio_url.split(".")[-1]
audio_url = f"./WORKSPACE/{uuid.uuid4().hex}.{audio_type}"
audio_data = requests.get(audio_url).content
with open(audio_url, "wb") as f:
f.write(audio_data)
wav_file = f"./WORKSPACE/{uuid.uuid4().hex}.wav"
AudioSegment.from_file(audio_url).set_frame_rate(16000).export(
wav_file, format="wav"
)
transcribed_audio = agent.agent.transcribe_audio(
file=audio_url, prompt=new_prompt
audio_path=wav_file
)
new_prompt += transcribed_audio
if "video_url" in message:
Expand Down
2 changes: 1 addition & 1 deletion agixt/providers/claude.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(

@staticmethod
def services():
return ["llm"]
return ["llm", "vision"]

async def inference(self, prompt, tokens: int = 0, images: list = []):
if (
Expand Down
2 changes: 1 addition & 1 deletion agixt/providers/ezlocalai.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(

@staticmethod
def services():
return ["llm", "tts", "transcription", "translation"]
return ["llm", "tts", "transcription", "translation", "vision"]

def rotate_uri(self):
self.FAILURES.append(self.API_URI)
Expand Down
22 changes: 19 additions & 3 deletions agixt/providers/google.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import os
from pathlib import Path

try:
import google.generativeai as genai # Primary import attempt
Expand Down Expand Up @@ -41,18 +42,33 @@ def __init__(

@staticmethod
def services():
return ["llm", "tts"]
return ["llm", "tts", "vision"]

async def inference(self, prompt, tokens: int = 0, images: list = []):
if not self.GOOGLE_API_KEY or self.GOOGLE_API_KEY == "None":
return "Please set your Google API key in the Agent Management page."
try:
genai.configure(api_key=self.GOOGLE_API_KEY)
model = genai.GenerativeModel(self.AI_MODEL)
new_max_tokens = int(self.MAX_TOKENS) - tokens
generation_config = genai.types.GenerationConfig(
max_output_tokens=new_max_tokens, temperature=float(self.AI_TEMPERATURE)
)
model = genai.GenerativeModel(
model_name=self.AI_MODEL if not images else "gemini-pro-vision",
generation_config=generation_config,
)
new_max_tokens = int(self.MAX_TOKENS) - tokens
new_prompt = []
if images:
for image in images:
file_extension = Path(image).suffix
new_prompt.append(
{
"mime_type": f"image/{file_extension}",
"data": Path(image).read_bytes(),
}
)
new_prompt.append(prompt)
prompt = new_prompt
response = await asyncio.to_thread(
model.generate_content,
contents=prompt,
Expand Down
79 changes: 29 additions & 50 deletions agixt/providers/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,37 +6,26 @@
import io
from PIL import Image

MODELS = {
"HuggingFaceH4/starchat-beta": 8192,
"OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5": 1512,
"bigcode/starcoderplus": 8192,
"bigcode/starcoder": 8192,
"bigcode/santacoder": 1512,
"EleutherAI/gpt-neox-20b": 1512,
"EleutherAI/gpt-neo-1.3B": 2048,
"RedPajama-INCITE-Instruct-3B-v1": 2048,
}


class HuggingfaceProvider:
def __init__(
self,
MODEL_PATH: str = "HuggingFaceH4/starchat-beta",
HUGGINGFACE_API_KEY: str = None,
HUGGINGFACE_API_URL: str = "https://api-inference.huggingface.co/models/{model}",
STABLE_DIFFUSION_MODEL: str = "runwayml/stable-diffusion-v1-5",
STABLE_DIFFUSION_API_URL: str = "https://api-inference.huggingface.co/models/runwayml/stable-diffusion-v1-5",
AI_MODEL: str = "starchat",
AI_MODEL: str = "HuggingFaceH4/zephyr-7b-beta",
stop=["<|end|>"],
MAX_TOKENS: int = 1024,
AI_TEMPERATURE: float = 0.7,
MAX_RETRIES: int = 15,
**kwargs,
):
self.requirements = []
self.MODEL_PATH = MODEL_PATH
self.AI_MODEL = AI_MODEL
self.HUGGINGFACE_API_KEY = HUGGINGFACE_API_KEY
self.HUGGINGFACE_API_URL = HUGGINGFACE_API_URL
self.HUGGINGFACE_API_URL = (
f"https://api-inference.huggingface.co/models/{self.AI_MODEL}"
)
if (
STABLE_DIFFUSION_MODEL != "runwayml/stable-diffusion-v1-5"
and STABLE_DIFFUSION_API_URL.startswith(
Expand All @@ -50,7 +39,6 @@ def __init__(
self.STABLE_DIFFUSION_API_URL = STABLE_DIFFUSION_API_URL
self.AI_TEMPERATURE = AI_TEMPERATURE
self.MAX_TOKENS = MAX_TOKENS
self.AI_MODEL = AI_MODEL
self.stop = stop
self.MAX_RETRIES = MAX_RETRIES
self.parameters = kwargs
Expand All @@ -59,29 +47,30 @@ def __init__(
def services():
return ["llm", "tts", "image"]

def get_url(self) -> str:
return self.HUGGINGFACE_API_URL.replace("{model}", self.MODEL_PATH)

def get_max_length(self):
if self.MODEL_PATH in MODELS:
return MODELS[self.MODEL_PATH]
return 4096

def get_max_new_tokens(self, input_length: int = 0) -> int:
return min(self.get_max_length() - input_length, self.MAX_TOKENS)

def request(self, inputs, **kwargs):
payload = {"inputs": inputs, "parameters": {**kwargs}}
async def inference(self, prompt, tokens: int = 0, images: list = []):
payload = {
"inputs": prompt,
"parameters": {
"temperature": self.AI_TEMPERATURE,
"max_new_tokens": (int(self.MAX_TOKENS) - tokens),
"return_full_text": False,
"stop": self.stop,
**self.parameters,
},
}
headers = {}
if self.HUGGINGFACE_API_KEY:
headers["Authorization"] = f"Bearer {self.HUGGINGFACE_API_KEY}"

tries = 0
while True:
tries += 1
if tries > self.MAX_RETRIES:
raise ValueError(f"Reached max retries: {self.MAX_RETRIES}")
response = requests.post(self.get_url(), json=payload, headers=headers)
response = requests.post(
self.HUGGINGFACE_API_URL,
json=payload,
headers=headers,
)
if response.status_code == 429:
logging.info(
f"Server Error {response.status_code}: Getting rate-limited / wait for {tries} seconds."
Expand All @@ -96,26 +85,16 @@ def request(self, inputs, **kwargs):
raise ValueError(f"Error {response.status_code}: {response.text}")
else:
break

content_type = response.headers["Content-Type"]
if content_type == "application/json":
return response.json()

async def inference(self, prompt, tokens: int = 0, images: list = []):
result = self.request(
prompt,
temperature=self.AI_TEMPERATURE,
max_new_tokens=self.get_max_new_tokens(tokens),
return_full_text=False,
stop=self.stop,
**self.parameters,
)[0]["generated_text"]
if self.stop:
for stop_seq in self.stop:
find = result.find(stop_seq)
if find >= 0:
result = result[:find]
return result
response = response.json()
result = response[0]["generated_text"]
if self.stop:
for stop_seq in self.stop:
find = result.find(stop_seq)
if find >= 0:
result = result[:find]
return result

async def generate_image(
self,
Expand Down
10 changes: 9 additions & 1 deletion agixt/providers/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,15 @@ def __init__(

@staticmethod
def services():
return ["llm", "tts", "image", "embeddings", "transcription", "translation"]
return [
"llm",
"tts",
"image",
"embeddings",
"transcription",
"translation",
"vision",
]

def rotate_uri(self):
self.FAILURES.append(self.API_URI)
Expand Down

0 comments on commit bcc4df7

Please sign in to comment.