diff --git a/backend/onyx/chat/prompt_builder/answer_prompt_builder.py b/backend/onyx/chat/prompt_builder/answer_prompt_builder.py
index b1beb211b12..603ba0a8fa7 100644
--- a/backend/onyx/chat/prompt_builder/answer_prompt_builder.py
+++ b/backend/onyx/chat/prompt_builder/answer_prompt_builder.py
@@ -15,6 +15,7 @@
 from onyx.llm.utils import build_content_with_imgs
 from onyx.llm.utils import check_message_tokens
 from onyx.llm.utils import message_to_prompt_and_imgs
+from onyx.llm.utils import model_supports_image_input
 from onyx.natural_language_processing.utils import get_tokenizer
 from onyx.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
 from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
@@ -86,6 +87,7 @@ def __init__(
             provider_type=llm_config.model_provider,
             model_name=llm_config.model_name,
         )
+        self.llm_config = llm_config
         self.llm_tokenizer_encode_func = cast(
             Callable[[str], list[int]], llm_tokenizer.encode
         )
@@ -94,12 +96,21 @@ def __init__(
         (
             self.message_history,
             self.history_token_cnts,
-        ) = translate_history_to_basemessages(message_history)
+        ) = translate_history_to_basemessages(
+            message_history,
+            exclude_images=not model_supports_image_input(
+                self.llm_config.model_name,
+                self.llm_config.model_provider,
+            ),
+        )
 
         self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
         self.user_message_and_token_cnt = (
             user_message,
-            check_message_tokens(user_message, self.llm_tokenizer_encode_func),
+            check_message_tokens(
+                user_message,
+                self.llm_tokenizer_encode_func,
+            ),
         )
 
         self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []
 
diff --git a/backend/onyx/chat/prompt_builder/utils.py b/backend/onyx/chat/prompt_builder/utils.py
index 7c231234c34..6b2a003078b 100644
--- a/backend/onyx/chat/prompt_builder/utils.py
+++ b/backend/onyx/chat/prompt_builder/utils.py
@@ -11,6 +11,7 @@
 
 def translate_onyx_msg_to_langchain(
     msg: ChatMessage | PreviousMessage,
+    exclude_images: bool = False,
 ) -> BaseMessage:
     files: list[InMemoryChatFile] = []
 
@@ -18,7 +19,9 @@
     # attached. Just ignore them for now.
     if not isinstance(msg, ChatMessage):
         files = msg.files
-    content = build_content_with_imgs(msg.message, files, message_type=msg.message_type)
+    content = build_content_with_imgs(
+        msg.message, files, message_type=msg.message_type, exclude_images=exclude_images
+    )
 
     if msg.message_type == MessageType.SYSTEM:
         raise ValueError("System messages are not currently part of history")
@@ -32,9 +35,12 @@
 
 def translate_history_to_basemessages(
     history: list[ChatMessage] | list["PreviousMessage"],
+    exclude_images: bool = False,
 ) -> tuple[list[BaseMessage], list[int]]:
     history_basemessages = [
-        translate_onyx_msg_to_langchain(msg) for msg in history if msg.token_count != 0
+        translate_onyx_msg_to_langchain(msg, exclude_images)
+        for msg in history
+        if msg.token_count != 0
     ]
     history_token_counts = [msg.token_count for msg in history if msg.token_count != 0]
     return history_basemessages, history_token_counts
diff --git a/backend/onyx/llm/utils.py b/backend/onyx/llm/utils.py
index 55b77be7909..279eea40fd4 100644
--- a/backend/onyx/llm/utils.py
+++ b/backend/onyx/llm/utils.py
@@ -142,6 +142,7 @@ def build_content_with_imgs(
     img_urls: list[str] | None = None,
     b64_imgs: list[str] | None = None,
     message_type: MessageType = MessageType.USER,
+    exclude_images: bool = False,
 ) -> str | list[str | dict[str, Any]]:
     # matching Langchain's BaseMessage content type
     files = files or []
@@ -157,7 +158,7 @@
 
     message_main_content = _build_content(message, files)
 
-    if not img_files and not img_urls:
+    if exclude_images or (not img_files and not img_urls):
         return message_main_content
 
     return cast(
@@ -382,9 +383,19 @@ def _strip_colon_from_model_name(model_name: str) -> str:
     return ":".join(model_name.split(":")[:-1]) if ":" in model_name else model_name
 
 
-def _find_model_obj(
-    model_map: dict, provider: str, model_names: list[str | None]
-) -> dict | None:
+def _find_model_obj(model_map: dict, provider: str, model_name: str) -> dict | None:
+    stripped_model_name = _strip_extra_provider_from_model_name(model_name)
+
+    model_names = [
+        model_name,
+        # Remove the leading extra provider. Usually for cases where the user has a
+        # custom model proxy which prepends another prefix.
+        stripped_model_name,
+        # Remove :XXXX from the end, if present. Needed for ollama.
+        _strip_colon_from_model_name(model_name),
+        _strip_colon_from_model_name(stripped_model_name),
+    ]
+
     # Filter out None values and deduplicate model names
     filtered_model_names = [name for name in model_names if name]
 
@@ -417,21 +428,10 @@ def get_llm_max_tokens(
         return GEN_AI_MAX_TOKENS
 
     try:
-        extra_provider_stripped_model_name = _strip_extra_provider_from_model_name(
-            model_name
-        )
         model_obj = _find_model_obj(
             model_map,
             model_provider,
-            [
-                model_name,
-                # Remove leading extra provider. Usually for cases where user has a
-                # customer model proxy which appends another prefix
-                extra_provider_stripped_model_name,
-                # remove :XXXX from the end, if present. Needed for ollama.
-                _strip_colon_from_model_name(model_name),
-                _strip_colon_from_model_name(extra_provider_stripped_model_name),
-            ],
+            model_name,
         )
         if not model_obj:
             raise RuntimeError(
@@ -523,3 +523,23 @@ def get_max_input_tokens(
         raise RuntimeError("No tokens for input for the LLM given settings")
 
     return input_toks
+
+
+def model_supports_image_input(model_name: str, model_provider: str) -> bool:
+    model_map = get_model_map()
+    try:
+        model_obj = _find_model_obj(
+            model_map,
+            model_provider,
+            model_name,
+        )
+        if not model_obj:
+            raise RuntimeError(
+                f"No litellm entry found for {model_provider}/{model_name}"
+            )
+        return model_obj.get("supports_vision", False)
+    except Exception:
+        logger.exception(
+            f"Failed to get model object for {model_provider}/{model_name}"
+        )
+        return False
diff --git a/backend/onyx/tools/tool_implementations/images/image_generation_tool.py b/backend/onyx/tools/tool_implementations/images/image_generation_tool.py
index c1e83e598dc..f4e19e1c283 100644
--- a/backend/onyx/tools/tool_implementations/images/image_generation_tool.py
+++ b/backend/onyx/tools/tool_implementations/images/image_generation_tool.py
@@ -16,6 +16,7 @@
 from onyx.llm.models import PreviousMessage
 from onyx.llm.utils import build_content_with_imgs
 from onyx.llm.utils import message_to_string
+from onyx.llm.utils import model_supports_image_input
 from onyx.prompts.constants import GENERAL_SEP_PAT
 from onyx.tools.message import ToolCallSummary
 from onyx.tools.models import ToolResponse
@@ -316,12 +317,21 @@ def build_next_prompt(
             for img in img_generation_response
             if img.image_data is not None
         ]
-        prompt_builder.update_user_prompt(
-            build_image_generation_user_prompt(
-                query=prompt_builder.get_user_message_content(),
-                img_urls=img_urls,
-                b64_imgs=b64_imgs,
-            )
+
+        user_prompt = build_image_generation_user_prompt(
+            query=prompt_builder.get_user_message_content(),
+            supports_image_input=model_supports_image_input(
+                prompt_builder.llm_config.model_name,
+                prompt_builder.llm_config.model_provider,
+            ),
+            prompts=[
+                response.revised_prompt
+                for response in img_generation_response
+            ],
+            img_urls=img_urls,
+            b64_imgs=b64_imgs,
         )
+        prompt_builder.update_user_prompt(user_prompt)
+
 
         return prompt_builder
diff --git a/backend/onyx/tools/tool_implementations/images/prompt.py b/backend/onyx/tools/tool_implementations/images/prompt.py
index c5bd7928c19..b45f9d15604 100644
--- a/backend/onyx/tools/tool_implementations/images/prompt.py
+++ b/backend/onyx/tools/tool_implementations/images/prompt.py
@@ -9,16 +9,34 @@
 Can you please summarize them in a sentence or two? Do NOT include image urls or bulleted lists.
 """
 
+IMG_GENERATION_SUMMARY_PROMPT_NO_IMAGES = """
+You have generated images based on the following query: "{query}".
+The prompts used to create these images were: {prompts}
+
+Describe the images you generated, summarizing the key elements and content in a sentence or two.
+Be specific about what was generated and respond as if you have seen them,
+without including any disclaimers or speculation.
+"""
+
 
 def build_image_generation_user_prompt(
     query: str,
+    supports_image_input: bool,
     img_urls: list[str] | None = None,
     b64_imgs: list[str] | None = None,
+    prompts: list[str] | None = None,
 ) -> HumanMessage:
-    return HumanMessage(
-        content=build_content_with_imgs(
-            message=IMG_GENERATION_SUMMARY_PROMPT.format(query=query).strip(),
-            b64_imgs=b64_imgs,
-            img_urls=img_urls,
+    if supports_image_input:
+        return HumanMessage(
+            content=build_content_with_imgs(
+                message=IMG_GENERATION_SUMMARY_PROMPT.format(query=query).strip(),
+                b64_imgs=b64_imgs,
+                img_urls=img_urls,
+            )
+        )
+    else:
+        return HumanMessage(
+            content=IMG_GENERATION_SUMMARY_PROMPT_NO_IMAGES.format(
+                query=query, prompts=prompts
+            ).strip()
         )
-    )
diff --git a/web/src/app/admin/assistants/AssistantEditor.tsx b/web/src/app/admin/assistants/AssistantEditor.tsx
index e4673cba559..67bd084a52a 100644
--- a/web/src/app/admin/assistants/AssistantEditor.tsx
+++ b/web/src/app/admin/assistants/AssistantEditor.tsx
@@ -443,26 +443,10 @@
     let enabledTools = Object.keys(values.enabled_tools_map)
       .map((toolId) => Number(toolId))
       .filter((toolId) => values.enabled_tools_map[toolId]);
+    const searchToolEnabled = searchTool ? enabledTools.includes(searchTool.id) : false;
 
-    const imageGenerationToolEnabled = imageGenerationTool
-      ? enabledTools.includes(imageGenerationTool.id)
-      : false;
-
-    if (imageGenerationToolEnabled) {
-      if (
-        // model must support image input for image generation
-        // to work
-        !checkLLMSupportsImageInput(
-          values.llm_model_version_override || defaultModelName || ""
-        )
-      ) {
-        enabledTools = enabledTools.filter(
-          (toolId) => toolId !== imageGenerationTool!.id
-        );
-      }
-    }
 
     // if disable_retrieval is set, set num_chunks to 0
     // to tell the backend to not fetch any documents
@@ -913,25 +897,20 @@
                           id={`enabled_tools_map.${imageGenerationTool.id}`}
                           name={`enabled_tools_map.${imageGenerationTool.id}`}
                           onCheckedChange={() => {
-                            if (
-                              currentLLMSupportsImageOutput &&
-                              isImageGenerationAvailable
-                            ) {
+                            if (isImageGenerationAvailable) {
                               toggleToolInValues(
                                 imageGenerationTool.id
                               );
                             }
                           }}
                           className={
-                            !currentLLMSupportsImageOutput ||
                             !isImageGenerationAvailable
                               ? "opacity-50 cursor-not-allowed"
                               : ""
                           }
                         />
-                        {(!currentLLMSupportsImageOutput ||
-                          !isImageGenerationAvailable) && (
+                        {!isImageGenerationAvailable && (

{!currentLLMSupportsImageOutput diff --git a/web/src/app/chat/ChatPage.tsx b/web/src/app/chat/ChatPage.tsx index a3ad5097d06..997ca93739e 100644 --- a/web/src/app/chat/ChatPage.tsx +++ b/web/src/app/chat/ChatPage.tsx @@ -49,6 +49,7 @@ import { useContext, useEffect, useLayoutEffect, + useMemo, useRef, useState, } from "react"; @@ -1627,7 +1628,7 @@ export function ChatPage({ setPopup({ type: "error", message: - "The current Assistant does not support image input. Please select an assistant with Vision support.", + "The current model does not support image input. Please select a model with Vision support.", }); return; } @@ -1845,6 +1846,14 @@ export function ChatPage({ // eslint-disable-next-line react-hooks/exhaustive-deps }, [messageHistory]); + const imageFileInMessageHistory = useMemo(() => { + return messageHistory + .filter((message) => message.type === "user") + .some((message) => + message.files.some((file) => file.type === ChatFileType.IMAGE) + ); + }, [messageHistory]); + const currentVisibleRange = visibleRange.get(currentSessionId()) || { start: 0, end: 0, @@ -1925,6 +1934,10 @@ export function ChatPage({ handleSlackChatRedirect(); }, [searchParams, router]); + useEffect(() => { + llmOverrideManager.updateImageFilesPresent(imageFileInMessageHistory); + }, [imageFileInMessageHistory]); + useEffect(() => { const handleKeyDown = (event: KeyboardEvent) => { if (event.metaKey || event.ctrlKey) { diff --git a/web/src/app/chat/input/LLMPopover.tsx b/web/src/app/chat/input/LLMPopover.tsx index 56f8eb86bd9..ee41eafa4f0 100644 --- a/web/src/app/chat/input/LLMPopover.tsx +++ b/web/src/app/chat/input/LLMPopover.tsx @@ -5,7 +5,6 @@ import { PopoverTrigger, } from "@/components/ui/popover"; import { ChatInputOption } from "./ChatInputOption"; -import { AnthropicSVG } from "@/components/icons/icons"; import { getDisplayNameForModel } from "@/lib/hooks"; import { checkLLMSupportsImageInput, @@ -19,6 +18,14 @@ import { import { Persona } from "@/app/admin/assistants/interfaces"; import { LlmOverrideManager } from "@/lib/hooks"; +import { + Tooltip, + TooltipContent, + TooltipProvider, + TooltipTrigger, +} from "@/components/ui/tooltip"; +import { FiAlertTriangle } from "react-icons/fi"; + interface LLMPopoverProps { llmProviders: LLMProviderDescriptor[]; llmOverrideManager: LlmOverrideManager; @@ -139,6 +146,22 @@ export default function LLMPopover({ ); } })()} + {llmOverrideManager.imageFilesPresent && + !checkLLMSupportsImageInput(name) && ( + + + + + + +

+ This LLM is not vision-capable and cannot process + image files present in your chat session. +

+
+ + + )} ); } diff --git a/web/src/components/admin/connectors/Popup.tsx b/web/src/components/admin/connectors/Popup.tsx index 51cb653a029..7502bf802d6 100644 --- a/web/src/components/admin/connectors/Popup.tsx +++ b/web/src/components/admin/connectors/Popup.tsx @@ -1,7 +1,8 @@ import { useRef, useState } from "react"; import { cva, type VariantProps } from "class-variance-authority"; import { cn } from "@/lib/utils"; -import { CheckCircle, XCircle } from "lucide-react"; +import { Check, CheckCircle, XCircle } from "lucide-react"; +import { Warning } from "@phosphor-icons/react"; const popupVariants = cva( "fixed bottom-4 left-4 p-4 rounded-lg shadow-xl text-white z-[10000] flex items-center space-x-3 transition-all duration-300 ease-in-out", { @@ -26,9 +27,9 @@ export interface PopupSpec extends VariantProps { export const Popup: React.FC = ({ message, type }) => (
     {type === "success" ? (
-      <CheckCircle />
+      <Check />
     ) : type === "error" ? (
-      <XCircle />
+      <Warning />
     ) : type === "info" ? (
diff --git a/web/src/lib/hooks.ts b/web/src/lib/hooks.ts
   updateModelOverrideForChatSession: (chatSession?: ChatSession) => void;
+  imageFilesPresent: boolean;
+  updateImageFilesPresent: (present: boolean) => void;
 }
 
 export function useLlmOverride(
   llmProviders: LLMProviderDescriptor[],
@@ -387,6 +389,11 @@ export function useLlmOverride(
     }
     return { name: "", provider: "", modelName: "" };
   };
+  const [imageFilesPresent, setImageFilesPresent] = useState(false);
+
+  const updateImageFilesPresent = (present: boolean) => {
+    setImageFilesPresent(present);
+  };
 
   const [globalDefault, setGlobalDefault] = useState<LlmOverride>(
     getValidLlmOverride(globalModel)
@@ -451,6 +458,8 @@ export function useLlmOverride(
     setGlobalDefault,
     temperature,
     updateTemperature,
+    imageFilesPresent,
+    updateImageFilesPresent,
   };
 }
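Note on the candidate-name lookup in backend/onyx/llm/utils.py above: _find_model_obj now derives the fallback names itself (the raw name, the name with any extra "provider/" prefix stripped, and Ollama-style ":tag" suffixes stripped), so callers such as get_llm_max_tokens and the new model_supports_image_input only pass a single model name. The standalone Python sketch below illustrates that lookup strategy; the plain dict standing in for litellm's model map, the example entries, and the helper names are assumptions for illustration only, not Onyx's or litellm's actual API.

# Illustrative sketch only: a plain dict stands in for litellm's model map.
from typing import Any


def strip_provider_prefix(name: str) -> str:
    # "custom-proxy/llama3:8b" -> "llama3:8b"; names without "/" pass through.
    return name.split("/")[-1]


def strip_colon_suffix(name: str) -> str:
    # "llama3:8b" -> "llama3"; Ollama appends size tags this way.
    return ":".join(name.split(":")[:-1]) if ":" in name else name


def find_model_entry(
    model_map: dict[str, dict[str, Any]], provider: str, model_name: str
) -> dict[str, Any] | None:
    stripped = strip_provider_prefix(model_name)
    # Try progressively more aggressive normalizations, preserving order and
    # skipping duplicates.
    candidates: list[str] = []
    for raw in (
        model_name,
        stripped,
        strip_colon_suffix(model_name),
        strip_colon_suffix(stripped),
    ):
        if raw and raw not in candidates:
            candidates.append(raw)
    for name in candidates:
        # Check both the provider-qualified key and the bare model name.
        for key in (f"{provider}/{name}", name):
            if key in model_map:
                return model_map[key]
    return None


# Hypothetical entries, loosely shaped like litellm's model metadata.
example_map = {
    "gpt-4o": {"supports_vision": True},
    "ollama/llama3": {"supports_vision": False},
}
entry = find_model_entry(example_map, "ollama", "custom-proxy/llama3:8b")
print(bool(entry and entry.get("supports_vision", False)))  # -> False

Usage mirrors the diff's model_supports_image_input: look the entry up, then read its "supports_vision" flag, defaulting to False when the model is unknown.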
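Similarly, the exclude_images flag that AnswerPromptBuilder now threads through translate_history_to_basemessages and build_content_with_imgs ultimately just keeps image parts out of the content handed to a non-vision model. Below is a minimal sketch of that filtering, assuming the LangChain-style list-of-parts content shape (text and image_url dicts); the helper name and example message are illustrative, not the actual Onyx implementation.

# Illustrative sketch only: strip image parts from LangChain-style message content.
from typing import Any

Content = str | list[dict[str, Any]]


def drop_image_parts(content: Content, exclude_images: bool) -> Content:
    # Plain strings and vision-capable requests are passed through untouched.
    if not exclude_images or isinstance(content, str):
        return content
    text_parts = [part for part in content if part.get("type") == "text"]
    if not text_parts:
        # Only images were present; fall back to empty text rather than an empty list.
        return ""
    if len(text_parts) == 1:
        # Collapse a lone text part back to a plain string, the common case.
        return text_parts[0].get("text", "")
    return text_parts


history_content: Content = [
    {"type": "text", "text": "What is in this picture?"},
    {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
]
print(drop_image_parts(history_content, exclude_images=True))  # -> "What is in this picture?"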