From 23924adb0b05a0eec47257b536f99807167f0ff0 Mon Sep 17 00:00:00 2001
From: TKS <32640296+bigsk1@users.noreply.github.com>
Date: Tue, 25 Jun 2024 06:31:26 -0700
Subject: [PATCH] async update

---
 app/app.py                | 109 +++++++++++++++++++++-----------
 app/app_logic.py          |   2 +-
 app/main.py               |   7 ++-
 app/static/css/styles.css |  60 +++++++++++++++------
 app/static/js/scripts.js  |  60 ++++++++++++++++-----
 app/templates/index.html  |  20 +++----
 6 files changed, 170 insertions(+), 88 deletions(-)

diff --git a/app/app.py b/app/app.py
index 36ea07b..1afe6b4 100644
--- a/app/app.py
+++ b/app/app.py
@@ -149,8 +149,7 @@ def open_file(filepath):
 
 # Function to play audio using PyAudio
 async def play_audio(file_path):
-    loop = asyncio.get_event_loop()
-    await loop.run_in_executor(None, sync_play_audio, file_path)
+    await asyncio.to_thread(sync_play_audio, file_path)
 
 def sync_play_audio(file_path):
     print("Starting audio playback")
@@ -215,10 +214,10 @@ async def process_and_play(prompt, audio_file_pth):
             await send_message_to_clients(json.dumps({"action": "ai_stop_speaking"}))
         else:
             print("Error: Audio file not found.")
-    else:
-        tts_model = xtts_model
+    else: # XTTS
         try:
-            outputs = await tts_model.synthesize(
+            tts_model = xtts_model
+            outputs = await asyncio.to_thread(tts_model.synthesize,
                 prompt,
                 xtts_config,
                 speaker_wav=audio_file_pth,
@@ -389,6 +388,7 @@ def adjust_prompt(mood):
 
 def chatgpt_streamed(user_input, system_message, mood_prompt, conversation_history):
     full_response = ""
+    print(f"Debug: chatgpt_streamed started. MODEL_PROVIDER: {MODEL_PROVIDER}")
     if MODEL_PROVIDER == 'ollama':
         headers = {'Content-Type': 'application/json'}
         payload = {
@@ -399,6 +399,7 @@ def chatgpt_streamed(user_input, system_message, mood_prompt, conversation_histo
             "options": {"num_predict": -2, "temperature": 1.0}
         }
         try:
+            print(f"Debug: Sending request to Ollama: {OLLAMA_BASE_URL}/v1/chat/completions")
             response = requests.post(f'{OLLAMA_BASE_URL}/v1/chat/completions', headers=headers, json=payload, stream=True, timeout=30)
             response.raise_for_status()
 
@@ -426,12 +427,14 @@
 
         except requests.exceptions.RequestException as e:
             full_response = f"Error connecting to Ollama model: {e}"
+            print(f"Debug: Ollama error - {e}")
 
     elif MODEL_PROVIDER == 'openai':
         messages = [{"role": "system", "content": system_message + "\n" + mood_prompt}] + conversation_history + [{"role": "user", "content": user_input}]
         headers = {'Authorization': f'Bearer {OPENAI_API_KEY}', 'Content-Type': 'application/json'}
         payload = {"model": OPENAI_MODEL, "messages": messages, "stream": True}
         try:
+            print(f"Debug: Sending request to OpenAI: {OPENAI_BASE_URL}")
             response = requests.post(OPENAI_BASE_URL, headers=headers, json=payload, stream=True, timeout=30)
             response.raise_for_status()
 
@@ -461,7 +464,9 @@
 
         except requests.exceptions.RequestException as e:
             full_response = f"Error connecting to OpenAI model: {e}"
+            print(f"Debug: OpenAI error - {e}")
 
+    print(f"Debug: chatgpt_streamed completed. Response length: {len(full_response)}")
     return full_response
 
 def transcribe_with_whisper(audio_file):
@@ -505,7 +510,7 @@ def record_audio(file_path, silence_threshold=512, silence_duration=4.0, chunk_s
     wf.writeframes(b''.join(frames))
     wf.close()
 
-def execute_once(question_prompt):
+async def execute_once(question_prompt):
     temp_image_path = os.path.join(output_dir, 'temp_img.jpg')
 
     # Determine the audio file format based on the TTS provider
@@ -519,8 +524,8 @@
         temp_audio_path = os.path.join(output_dir, 'temp_audio.wav') # Use wav for XTTS
         max_char_length = 250 # Set a lower limit for XTTS
 
-    image_path = take_screenshot(temp_image_path)
-    response = analyze_image(image_path, question_prompt)
+    image_path = await take_screenshot(temp_image_path)
+    response = await analyze_image(image_path, question_prompt)
     text_response = response.get('choices', [{}])[0].get('message', {}).get('content', 'No response received.')
 
     # Truncate response based on the TTS provider's limit
@@ -529,40 +534,40 @@
 
     print(text_response)
 
-    generate_speech(text_response, temp_audio_path)
+    await generate_speech(text_response, temp_audio_path)
 
     if TTS_PROVIDER == 'elevenlabs':
         # Convert MP3 to WAV if ElevenLabs is used
         temp_wav_path = os.path.join(output_dir, 'temp_output.wav')
         audio = AudioSegment.from_mp3(temp_audio_path)
         audio.export(temp_wav_path, format="wav")
-        play_audio(temp_wav_path)
+        await play_audio(temp_wav_path)
     else:
-        play_audio(temp_audio_path)
+        await play_audio(temp_audio_path)
 
     os.remove(image_path)
 
-
-def execute_screenshot_and_analyze():
+async def execute_screenshot_and_analyze():
     question_prompt = "What do you see in this image? Keep it short but detailed and answer any follow up questions about it"
     print("Taking screenshot and analyzing...")
-    execute_once(question_prompt)
+    await execute_once(question_prompt)
     print("\nReady for the next question....")
-
-def take_screenshot(temp_image_path):
-    time.sleep(5)
+
+async def take_screenshot(temp_image_path):
+    await asyncio.sleep(5)
     screenshot = ImageGrab.grab()
     screenshot = screenshot.resize((1024, 1024))
     screenshot.save(temp_image_path, 'JPEG')
     return temp_image_path
 
-def encode_image(image_path):
+# Encode Image
+async def encode_image(image_path):
     with open(image_path, "rb") as image_file:
         return base64.b64encode(image_file.read()).decode('utf-8')
 
-def analyze_image(image_path, question_prompt):
-    encoded_image = encode_image(image_path)
-
+# Analyze Image
+async def analyze_image(image_path, question_prompt):
+    encoded_image = await encode_image(image_path)
     if MODEL_PROVIDER == 'ollama':
         headers = {'Content-Type': 'application/json'}
         payload = {
@@ -572,15 +577,17 @@
             "stream": False
         }
         try:
-            response = requests.post(f'{OLLAMA_BASE_URL}/api/generate', headers=headers, json=payload, timeout=30)
-            print(f"Response status code: {response.status_code}")
-            if response.status_code == 200:
-                return {"choices": [{"message": {"content": response.json().get('response', 'No response received.')}}]}
-            elif response.status_code == 404:
-                return {"choices": [{"message": {"content": "The llava model is not available on this server."}}]}
-            else:
-                response.raise_for_status()
-        except requests.exceptions.RequestException as e:
+            async with aiohttp.ClientSession() as session:
+                async with session.post(f'{OLLAMA_BASE_URL}/api/generate', headers=headers, json=payload, timeout=30) as response:
+                    print(f"Response status code: {response.status}")
+                    if response.status == 200:
+                        response_json = await response.json()
+                        return {"choices": [{"message": {"content": response_json.get('response', 'No response received.')}}]}
+                    elif response.status == 404:
+                        return {"choices": [{"message": {"content": "The llava model is not available on this server."}}]}
+                    else:
+                        response.raise_for_status()
+        except aiohttp.ClientError as e:
             print(f"Request failed: {e}")
             return {"choices": [{"message": {"content": "Failed to process the image with the llava model."}}]}
     else:
@@ -594,29 +601,35 @@
         }
         payload = {"model": OPENAI_MODEL, "temperature": 0.5, "messages": [message], "max_tokens": 1000}
         try:
-            response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload, timeout=30)
-            response.raise_for_status()
-            return response.json()
-        except requests.exceptions.RequestException as e:
+            async with aiohttp.ClientSession() as session:
+                async with session.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload, timeout=30) as response:
+                    response.raise_for_status()
+                    return await response.json()
+        except aiohttp.ClientError as e:
             print(f"Request failed: {e}")
             return {"choices": [{"message": {"content": "Failed to process the image with the OpenAI model."}}]}
 
-def generate_speech(text, temp_audio_path):
+
+async def generate_speech(text, temp_audio_path):
     if TTS_PROVIDER == 'openai':
         headers = {"Content-Type": "application/json", "Authorization": f"Bearer {OPENAI_API_KEY}"}
         payload = {"model": "tts-1", "voice": OPENAI_TTS_VOICE, "input": text, "response_format": "wav"}
-        response = requests.post(OPENAI_TTS_URL, headers=headers, json=payload, timeout=30)
-        if response.status_code == 200:
-            with open(temp_audio_path, "wb") as audio_file:
-                audio_file.write(response.content)
-        else:
-            print(f"Failed to generate speech: {response.status_code} - {response.text}")
+        async with aiohttp.ClientSession() as session:
+            async with session.post(OPENAI_TTS_URL, headers=headers, json=payload, timeout=30) as response:
+                if response.status == 200:
+                    with open(temp_audio_path, "wb") as audio_file:
+                        audio_file.write(await response.read())
+                else:
+                    print(f"Failed to generate speech: {response.status} - {await response.text()}")
+
     elif TTS_PROVIDER == 'elevenlabs':
-        elevenlabs_text_to_speech(text, temp_audio_path)
-    else:
+        await elevenlabs_text_to_speech(text, temp_audio_path)
+
+    else: # XTTS
         tts_model = xtts_model
         try:
-            outputs = tts_model.synthesize(
+            outputs = await asyncio.to_thread(
+                tts_model.synthesize,
                 text,
                 xtts_config,
                 speaker_wav=character_audio_file,
@@ -632,7 +645,7 @@
         except Exception as e:
             print(f"Error during XTTS audio generation: {e}")
 
-def user_chatbot_conversation():
+async def user_chatbot_conversation():
     conversation_history = []
     base_system_message = open_file(character_prompt_file)
     quit_phrases = ["quit", "Quit", "Quit.", "Exit.", "exit", "Exit", "leave", "Leave."]
@@ -659,7 +672,7 @@
             conversation_history.append({"role": "user", "content": user_input})
 
             if any(phrase in user_input.lower() for phrase in screenshot_phrases):
-                execute_screenshot_and_analyze()
+                await execute_screenshot_and_analyze()  # Note the 'await' here
                 continue
 
             mood = analyze_mood(user_input)
@@ -672,11 +685,11 @@
             if len(sanitized_response) > 400:
                 sanitized_response = sanitized_response[:400] + "..."
             prompt2 = sanitized_response
-            process_and_play(prompt2, character_audio_file)
+            await process_and_play(prompt2, character_audio_file)  # Note the 'await' here
             if len(conversation_history) > 20:
                 conversation_history = conversation_history[-20:]
     except KeyboardInterrupt:
         print("Quitting the conversation...")
 
 if __name__ == "__main__":
-    user_chatbot_conversation()
+    asyncio.run(user_chatbot_conversation())
diff --git a/app/app_logic.py b/app/app_logic.py
index 8b10484..b7845d4 100644
--- a/app/app_logic.py
+++ b/app/app_logic.py
@@ -93,7 +93,7 @@ async def conversation_loop():
             break
 
         if any(phrase in user_input.lower() for phrase in screenshot_phrases):
-            execute_screenshot_and_analyze()
+            await execute_screenshot_and_analyze()
             continue
 
         try:
diff --git a/app/main.py b/app/main.py
index 8ced303..f938cb4 100644
--- a/app/main.py
+++ b/app/main.py
@@ -10,6 +10,7 @@
 from fastapi.middleware.cors import CORSMiddleware
 from .shared import clients, get_current_character, set_current_character
 from .app_logic import start_conversation, stop_conversation, set_env_variable
+# from .app import user_chatbot_conversation
 
 app = FastAPI()
 
@@ -25,6 +26,10 @@
     allow_headers=["*"],
 )
 
+# @app.on_event("startup")
+# async def startup_event():
+#     asyncio.create_task(user_chatbot_conversation())
+
 @app.get("/")
 async def get(request: Request):
     model_provider = os.getenv("MODEL_PROVIDER")
@@ -85,7 +90,7 @@ async def websocket_endpoint(websocket: WebSocket):
             await start_conversation()
         elif message["action"] == "set_character":
             set_current_character(message["character"])
-            await websocket.send_json({"message": f"Character set to {message['character']}"})
+            await websocket.send_json({"message": f"Character: {message['character']}"})
         elif message["action"] == "set_provider":
             set_env_variable("MODEL_PROVIDER", message["provider"])
         elif message["action"] == "set_tts":
diff --git a/app/static/css/styles.css b/app/static/css/styles.css
index 8c634f2..3e5529c 100644
--- a/app/static/css/styles.css
+++ b/app/static/css/styles.css
@@ -1,16 +1,18 @@
 :root {
-    --bg-color: #e4e2e2;
+    --bg-color: #f0f2f5;
     --text-color: #333333;
-    --primary-color: #cacaca;
-    --secondary-color: #ecebeb;
-    --border-color: #e0e0e0;
+    --primary-color: #1877f2;
+    --secondary-color: #e4e6eb;
+    --border-color: #dddfe2;
     --shadow-color: rgba(0, 0, 0, 0.1);
     --conversation-bg: #ffffff;
-    --ai-message-color: #2b6fbd;
-    --user-message-color: #333333;
-    --button-bg: #4a90e2;
+    --ai-message-color: #1877f2;
+    --user-message-color: #ffffff;
+    --ai-message-bg: #e6f2ff;
+    --user-message-bg: #f0f2f5;
+    --button-bg: #1877f2;
     --button-text: #ffffff;
-    --button-hover: #3a7dca;
+    --button-hover: #166fe5;
     --select-bg: #ffffff;
     --select-text: #333333;
 }
@@ -77,8 +79,8 @@ main {
 
 #conversation {
     width: 100%;
-    height: 50vh;
-    min-height: 300px;
+    height: 60vh;
+    min-height: 350px;
     border: 1px solid var(--border-color);
     border-radius: 8px;
     padding: 1rem;
@@ -87,6 +89,7 @@ main {
     margin-bottom: 1rem;
     box-shadow: 0 0 10px var(--shadow-color);
     box-sizing: border-box;
+    position: relative;
 }
 
 .ai-message, .user-message {
@@ -100,13 +103,20 @@ main {
     background-color: var(--secondary-color);
     color: var(--ai-message-color);
     align-self: flex-start;
+    margin-left: 0;
+    margin-right: 0;
+    max-width: calc(100% - 120px);
 }
 
 .user-message {
     background-color: var(--primary-color);
     color: var(--user-message-color);
     align-self: flex-end;
-    margin-left: auto;
+    margin-left: 60%;
+}
+
+#messages {
+    padding-bottom: 20px; /* Add padding to the bottom of the messages container */
 }
 
 .controls {
@@ -125,13 +135,18 @@ button {
     border: none;
     border-radius: 4px;
     cursor: pointer;
-    transition: background-color 0.3s;
+    transition: background-color 0.3s, transform 0.1s;
     flex-grow: 1;
     max-width: 200px;
 }
 
 button:hover {
     background-color: var(--button-hover);
+    transform: translateY(-1px);
+}
+
+button:active {
+    transform: translateY(1px);
 }
 
 .settings {
@@ -155,8 +170,14 @@ select {
     background-color: var(--select-bg);
     color: var(--select-text);
     border: 1px solid var(--border-color);
-    border-radius: 4px;
+    border-radius: 6px;
     width: 100%;
+    transition: border-color 0.3s;
+}
+
+select:focus {
+    outline: none;
+    border-color: var(--primary-color);
 }
 
 footer {
@@ -203,12 +224,19 @@ h1 {
 } */
 
 #voice-animation {
-    width: 60px;
-    height: 60px;
-    margin: 0 50px;
+    position: sticky;
+    bottom: 10px;
+    left: 50%;
+    transform: translateX(-50%);
+    width: 150px;
+    height: 150px;
     display: flex;
     justify-content: center;
     align-items: center;
+    z-index: 10;
+    background-color: var(--conversation-bg);
+    border-radius: 50%;
+    box-shadow: 0 0 10px var(--shadow-color);
 }
 
 #voice-animation svg {
diff --git a/app/static/js/scripts.js b/app/static/js/scripts.js
index 02c95bc..042fa3b 100644
--- a/app/static/js/scripts.js
+++ b/app/static/js/scripts.js
@@ -8,6 +8,9 @@ document.addEventListener("DOMContentLoaded", function() {
     const clearButton = document.getElementById('clear-conversation-btn');
     const messages = document.getElementById('messages');
 
+    let aiMessageQueue = [];
+    let isAISpeaking = false;
+
     websocket.onopen = function(event) {
         console.log("WebSocket is open now.");
         startButton.disabled = false;
@@ -24,7 +27,7 @@ document.addEventListener("DOMContentLoaded", function() {
     };
 
     websocket.onmessage = function(event) {
-        console.log("Received message:", event.data); // Debugging line
+        console.log("Received message:", event.data);
         let data;
         try {
             data = JSON.parse(event.data);
@@ -34,16 +37,51 @@ document.addEventListener("DOMContentLoaded", function() {
         }
 
         if (data.action === "ai_start_speaking") {
+            isAISpeaking = true;
             showVoiceAnimation();
+            setTimeout(processQueuedMessages, 100);
         } else if (data.action === "ai_stop_speaking") {
+            isAISpeaking = false;
             hideVoiceAnimation();
+            processQueuedMessages();
         } else if (data.message) {
-            displayMessage(data.message);
+            if (data.message.startsWith('You:')) {
+                displayMessage(data.message);
+            } else {
+                aiMessageQueue.push(data.message);
+                if (!isAISpeaking) {
+                    processQueuedMessages();
+                }
+            }
         } else {
             displayMessage(event.data);
         }
     };
 
+    function processQueuedMessages() {
+        while (aiMessageQueue.length > 0 && !isAISpeaking) {
+            displayMessage(aiMessageQueue.shift());
+        }
+    }
+
+    function showVoiceAnimation() {
+        voiceAnimation.classList.remove('hidden');
+        adjustScrollPosition();
+    }
+
+    function hideVoiceAnimation() {
+        voiceAnimation.classList.add('hidden');
+        adjustScrollPosition();
+        processQueuedMessages();
+    }
+
+    function adjustScrollPosition() {
+        const conversation = document.getElementById('conversation');
+        if (isAISpeaking) {
+            conversation.scrollTop = conversation.scrollHeight;
+        }
+    }
+
     function displayMessage(message) {
         let formattedMessage = message;
         if (formattedMessage.includes('```')) {
@@ -60,9 +98,7 @@ document.addEventListener("DOMContentLoaded", function() {
         }
         messageElement.innerHTML = formattedMessage;
         messages.appendChild(messageElement);
-
-        const conversation = document.getElementById('conversation');
-        conversation.scrollTop = conversation.scrollHeight;
+        adjustScrollPosition();
     }
 
     startButton.addEventListener('click', function() {
@@ -80,13 +116,13 @@ document.addEventListener("DOMContentLoaded", function() {
         messages.innerHTML = '';
     });
 
-    function showVoiceAnimation() {
-        voiceAnimation.classList.remove('hidden');
-    }
-
-    function hideVoiceAnimation() {
-        voiceAnimation.classList.add('hidden');
-    }
+    messages.addEventListener('scroll', function() {
+        if (isAISpeaking) {
+            const conversation = document.getElementById('conversation');
+            const isScrolledToBottom = conversation.scrollHeight - conversation.clientHeight <= conversation.scrollTop + 1;
+            voiceAnimation.style.opacity = isScrolledToBottom ? '1' : '0';
+        }
+    });
 
     function setProvider() {
         const selectedProvider = document.getElementById('provider-select').value;
diff --git a/app/templates/index.html b/app/templates/index.html
index 590ff25..6c6bea2 100644
--- a/app/templates/index.html
+++ b/app/templates/index.html
@@ -14,16 +14,6 @@

Voice Chat AI

-
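
Note (not part of the patch): the conversion above follows two patterns throughout — blocking helpers are pushed onto a worker thread with asyncio.to_thread, and requests calls are replaced with an awaitable aiohttp session. Below is a minimal, self-contained sketch of both patterns, assuming Python 3.9+ and aiohttp installed; the names blocking_playback and fetch_completion are hypothetical stand-ins, not functions from this repository.

# Minimal sketch of the two async conversion patterns used in this patch.
# Assumes Python 3.9+ (asyncio.to_thread) and the aiohttp package.
import asyncio
import time

import aiohttp


def blocking_playback(path: str) -> None:
    """Stand-in for a synchronous, blocking helper such as sync_play_audio()."""
    time.sleep(1)  # simulate blocking work
    print(f"played {path}")


async def play(path: str) -> None:
    # Pattern 1: hand the blocking call to a worker thread so the event loop stays responsive.
    await asyncio.to_thread(blocking_playback, path)


async def fetch_completion(base_url: str, payload: dict) -> dict:
    # Pattern 2: replace requests.post(...) with an aiohttp session so the HTTP call is awaitable.
    timeout = aiohttp.ClientTimeout(total=30)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.post(f"{base_url}/v1/chat/completions", json=payload) as response:
            response.raise_for_status()
            return await response.json()


async def main() -> None:
    await play("temp_audio.wav")
    # Hypothetical endpoint and payload; uncomment and adjust to test against a real server:
    # result = await fetch_completion("http://localhost:11434", {"model": "llama3", "messages": []})


if __name__ == "__main__":
    asyncio.run(main())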