From bcc4df7a848da9931878c42504b3aa22d3fc4e9a Mon Sep 17 00:00:00 2001
From: Josh XT <102809327+Josh-XT@users.noreply.github.com>
Date: Tue, 2 Apr 2024 22:28:43 -0400
Subject: [PATCH] Add Gemini Vision, convert incoming audio (#1158)

* Convert incoming audio to wav

* Add vision to Google provider, clean up others
---
 agixt/endpoints/Completions.py | 23 +++++++++++-
 agixt/providers/claude.py      |  2 +-
 agixt/providers/ezlocalai.py   |  2 +-
 agixt/providers/google.py      | 22 ++++++++++--
 agixt/providers/huggingface.py | 79 +++++++++++++---------------------
 agixt/providers/openai.py      | 10 ++++-
 6 files changed, 81 insertions(+), 57 deletions(-)

diff --git a/agixt/endpoints/Completions.py b/agixt/endpoints/Completions.py
index c5c2e66c4bd..8890e8de47c 100644
--- a/agixt/endpoints/Completions.py
+++ b/agixt/endpoints/Completions.py
@@ -2,6 +2,7 @@
 import base64
 import uuid
 import json
+import requests
 from fastapi import APIRouter, Depends, Header
 from Interactions import Interactions, get_tokens, log_interaction
 from ApiClient import Agent, verify_api_key, get_api_client
@@ -19,6 +20,7 @@
     TextToSpeech,
     ImageCreation,
 )
+from pydub import AudioSegment
 
 app = APIRouter()
 
@@ -177,8 +179,27 @@ async def chat_completion(
                         if "url" in message["audio_url"]
                         else message["audio_url"]
                     )
+                    # If it is not a url, we need to find the file type and convert with pydub
+                    if not audio_url.startswith("http"):
+                        file_type = audio_url.split(",")[0].split("/")[1].split(";")[0]
+                        audio_data = base64.b64decode(audio_url.split(",")[1])
+                        audio_path = f"./WORKSPACE/{uuid.uuid4().hex}.{file_type}"
+                        with open(audio_path, "wb") as f:
+                            f.write(audio_data)
+                        audio_url = audio_path
+                    else:
+                        # Download the audio file from the url, get the file type and convert to wav
+                        audio_type = audio_url.split(".")[-1]
+                        audio_data = requests.get(audio_url).content
+                        audio_url = f"./WORKSPACE/{uuid.uuid4().hex}.{audio_type}"
+                        with open(audio_url, "wb") as f:
+                            f.write(audio_data)
+                    wav_file = f"./WORKSPACE/{uuid.uuid4().hex}.wav"
+                    AudioSegment.from_file(audio_url).set_frame_rate(16000).export(
+                        wav_file, format="wav"
+                    )
                     transcribed_audio = agent.agent.transcribe_audio(
-                        file=audio_url, prompt=new_prompt
+                        audio_path=wav_file
                     )
                     new_prompt += transcribed_audio
                 if "video_url" in message:
diff --git a/agixt/providers/claude.py b/agixt/providers/claude.py
index df12f00103d..41e40f9064f 100644
--- a/agixt/providers/claude.py
+++ b/agixt/providers/claude.py
@@ -31,7 +31,7 @@ def __init__(
 
     @staticmethod
     def services():
-        return ["llm"]
+        return ["llm", "vision"]
 
     async def inference(self, prompt, tokens: int = 0, images: list = []):
         if (
diff --git a/agixt/providers/ezlocalai.py b/agixt/providers/ezlocalai.py
index 2c5fe17cf61..37f429047a1 100644
--- a/agixt/providers/ezlocalai.py
+++ b/agixt/providers/ezlocalai.py
@@ -55,7 +55,7 @@ def __init__(
 
     @staticmethod
     def services():
-        return ["llm", "tts", "transcription", "translation"]
+        return ["llm", "tts", "transcription", "translation", "vision"]
 
     def rotate_uri(self):
         self.FAILURES.append(self.API_URI)
diff --git a/agixt/providers/google.py b/agixt/providers/google.py
index 6f9a53a2190..25320be05f6 100644
--- a/agixt/providers/google.py
+++ b/agixt/providers/google.py
@@ -1,5 +1,6 @@
 import asyncio
 import os
+from pathlib import Path
 
 try:
     import google.generativeai as genai  # Primary import attempt
@@ -41,18 +42,33 @@ def __init__(
 
     @staticmethod
     def services():
-        return ["llm", "tts"]
+        return ["llm", "tts", "vision"]
 
     async def inference(self, prompt, tokens: int = 0, images: list = []):
         if not self.GOOGLE_API_KEY or self.GOOGLE_API_KEY == "None":
             return "Please set your Google API key in the Agent Management page."
         try:
             genai.configure(api_key=self.GOOGLE_API_KEY)
-            model = genai.GenerativeModel(self.AI_MODEL)
-            new_max_tokens = int(self.MAX_TOKENS) - tokens
+            new_max_tokens = int(self.MAX_TOKENS) - tokens
             generation_config = genai.types.GenerationConfig(
                 max_output_tokens=new_max_tokens, temperature=float(self.AI_TEMPERATURE)
             )
+            model = genai.GenerativeModel(
+                model_name=self.AI_MODEL if not images else "gemini-pro-vision",
+                generation_config=generation_config,
+            )
+            new_prompt = []
+            if images:
+                for image in images:
+                    file_extension = Path(image).suffix.lstrip(".")
+                    new_prompt.append(
+                        {
+                            "mime_type": f"image/{file_extension}",
+                            "data": Path(image).read_bytes(),
+                        }
+                    )
+                new_prompt.append(prompt)
+                prompt = new_prompt
             response = await asyncio.to_thread(
                 model.generate_content,
                 contents=prompt,
diff --git a/agixt/providers/huggingface.py b/agixt/providers/huggingface.py
index 8395ca20a65..4d6a3590be3 100644
--- a/agixt/providers/huggingface.py
+++ b/agixt/providers/huggingface.py
@@ -6,27 +6,14 @@
 import io
 from PIL import Image
 
-MODELS = {
-    "HuggingFaceH4/starchat-beta": 8192,
-    "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5": 1512,
-    "bigcode/starcoderplus": 8192,
-    "bigcode/starcoder": 8192,
-    "bigcode/santacoder": 1512,
-    "EleutherAI/gpt-neox-20b": 1512,
-    "EleutherAI/gpt-neo-1.3B": 2048,
-    "RedPajama-INCITE-Instruct-3B-v1": 2048,
-}
-
 
 class HuggingfaceProvider:
     def __init__(
         self,
-        MODEL_PATH: str = "HuggingFaceH4/starchat-beta",
         HUGGINGFACE_API_KEY: str = None,
-        HUGGINGFACE_API_URL: str = "https://api-inference.huggingface.co/models/{model}",
         STABLE_DIFFUSION_MODEL: str = "runwayml/stable-diffusion-v1-5",
         STABLE_DIFFUSION_API_URL: str = "https://api-inference.huggingface.co/models/runwayml/stable-diffusion-v1-5",
-        AI_MODEL: str = "starchat",
+        AI_MODEL: str = "HuggingFaceH4/zephyr-7b-beta",
         stop=["<|end|>"],
         MAX_TOKENS: int = 1024,
         AI_TEMPERATURE: float = 0.7,
@@ -34,9 +21,11 @@ def __init__(
         **kwargs,
     ):
         self.requirements = []
-        self.MODEL_PATH = MODEL_PATH
+        self.AI_MODEL = AI_MODEL
         self.HUGGINGFACE_API_KEY = HUGGINGFACE_API_KEY
-        self.HUGGINGFACE_API_URL = HUGGINGFACE_API_URL
+        self.HUGGINGFACE_API_URL = (
+            f"https://api-inference.huggingface.co/models/{self.AI_MODEL}"
+        )
         if (
             STABLE_DIFFUSION_MODEL != "runwayml/stable-diffusion-v1-5"
             and STABLE_DIFFUSION_API_URL.startswith(
@@ -50,7 +39,6 @@ def __init__(
         self.STABLE_DIFFUSION_API_URL = STABLE_DIFFUSION_API_URL
         self.AI_TEMPERATURE = AI_TEMPERATURE
         self.MAX_TOKENS = MAX_TOKENS
-        self.AI_MODEL = AI_MODEL
         self.stop = stop
         self.MAX_RETRIES = MAX_RETRIES
         self.parameters = kwargs
@@ -59,29 +47,30 @@
     def services():
         return ["llm", "tts", "image"]
 
-    def get_url(self) -> str:
-        return self.HUGGINGFACE_API_URL.replace("{model}", self.MODEL_PATH)
-
-    def get_max_length(self):
-        if self.MODEL_PATH in MODELS:
-            return MODELS[self.MODEL_PATH]
-        return 4096
-
-    def get_max_new_tokens(self, input_length: int = 0) -> int:
-        return min(self.get_max_length() - input_length, self.MAX_TOKENS)
-
-    def request(self, inputs, **kwargs):
-        payload = {"inputs": inputs, "parameters": {**kwargs}}
+    async def inference(self, prompt, tokens: int = 0, images: list = []):
+        payload = {
+            "inputs": prompt,
+            "parameters": {
+                "temperature": self.AI_TEMPERATURE,
+                "max_new_tokens": (int(self.MAX_TOKENS) - tokens),
+                "return_full_text": False,
+                "stop": self.stop,
+                **self.parameters,
+            },
+        }
         headers = {}
         if self.HUGGINGFACE_API_KEY:
headers["Authorization"] = f"Bearer {self.HUGGINGFACE_API_KEY}" - tries = 0 while True: tries += 1 if tries > self.MAX_RETRIES: raise ValueError(f"Reached max retries: {self.MAX_RETRIES}") - response = requests.post(self.get_url(), json=payload, headers=headers) + response = requests.post( + self.HUGGINGFACE_API_URL, + json=payload, + headers=headers, + ) if response.status_code == 429: logging.info( f"Server Error {response.status_code}: Getting rate-limited / wait for {tries} seconds." @@ -96,26 +85,16 @@ def request(self, inputs, **kwargs): raise ValueError(f"Error {response.status_code}: {response.text}") else: break - content_type = response.headers["Content-Type"] if content_type == "application/json": - return response.json() - - async def inference(self, prompt, tokens: int = 0, images: list = []): - result = self.request( - prompt, - temperature=self.AI_TEMPERATURE, - max_new_tokens=self.get_max_new_tokens(tokens), - return_full_text=False, - stop=self.stop, - **self.parameters, - )[0]["generated_text"] - if self.stop: - for stop_seq in self.stop: - find = result.find(stop_seq) - if find >= 0: - result = result[:find] - return result + response = response.json() + result = response[0]["generated_text"] + if self.stop: + for stop_seq in self.stop: + find = result.find(stop_seq) + if find >= 0: + result = result[:find] + return result async def generate_image( self, diff --git a/agixt/providers/openai.py b/agixt/providers/openai.py index dd217f34d59..09440ca151c 100644 --- a/agixt/providers/openai.py +++ b/agixt/providers/openai.py @@ -57,7 +57,15 @@ def __init__( @staticmethod def services(): - return ["llm", "tts", "image", "embeddings", "transcription", "translation"] + return [ + "llm", + "tts", + "image", + "embeddings", + "transcription", + "translation", + "vision", + ] def rotate_uri(self): self.FAILURES.append(self.API_URI)