Allow all LLMs for image generation assistants #3730

Merged Jan 24, 2025 (5 commits). The diff below shows changes from 3 commits.
15 changes: 13 additions & 2 deletions backend/onyx/chat/prompt_builder/answer_prompt_builder.py
@@ -15,6 +15,7 @@
from onyx.llm.utils import build_content_with_imgs
from onyx.llm.utils import check_message_tokens
from onyx.llm.utils import message_to_prompt_and_imgs
from onyx.llm.utils import model_supports_image_input
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
@@ -86,6 +87,7 @@ def __init__(
provider_type=llm_config.model_provider,
model_name=llm_config.model_name,
)
self.llm_config = llm_config
self.llm_tokenizer_encode_func = cast(
Callable[[str], list[int]], llm_tokenizer.encode
)
@@ -94,12 +96,21 @@
(
self.message_history,
self.history_token_cnts,
) = translate_history_to_basemessages(message_history)
) = translate_history_to_basemessages(
message_history,
exclude_images=not model_supports_image_input(
self.llm_config.model_name,
self.llm_config.model_provider,
),
)

self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
self.user_message_and_token_cnt = (
user_message,
check_message_tokens(user_message, self.llm_tokenizer_encode_func),
check_message_tokens(
user_message,
self.llm_tokenizer_encode_func,
),
)

self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []
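The effect of this change, as a standalone sketch in plain Python (hypothetical helper name and message shape, not the Onyx API): when the configured model cannot accept image input, image parts are stripped from prior messages before the history is replayed to the LLM.

```python
# Standalone sketch (hypothetical message shape): drop image parts from chat history
# when the active model lacks vision support, keeping only the text parts.
def strip_images_if_unsupported(
    history: list[dict], supports_vision: bool
) -> list[dict]:
    if supports_vision:
        return history
    cleaned = []
    for msg in history:
        content = msg["content"]
        # Multimodal messages carry content as a list of {"type": ...} parts.
        if isinstance(content, list):
            content = [part for part in content if part.get("type") == "text"]
        cleaned.append({**msg, "content": content})
    return cleaned
```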
10 changes: 8 additions & 2 deletions backend/onyx/chat/prompt_builder/utils.py
@@ -11,14 +11,17 @@

def translate_onyx_msg_to_langchain(
msg: ChatMessage | PreviousMessage,
exclude_images: bool = False,
) -> BaseMessage:
files: list[InMemoryChatFile] = []

# If the message is a `ChatMessage`, it doesn't have the downloaded files
# attached. Just ignore them for now.
if not isinstance(msg, ChatMessage):
files = msg.files
content = build_content_with_imgs(msg.message, files, message_type=msg.message_type)
content = build_content_with_imgs(
msg.message, files, message_type=msg.message_type, exclude_images=exclude_images
)

if msg.message_type == MessageType.SYSTEM:
raise ValueError("System messages are not currently part of history")
@@ -32,9 +35,12 @@ def translate_onyx_msg_to_langchain(

def translate_history_to_basemessages(
history: list[ChatMessage] | list["PreviousMessage"],
exclude_images: bool = False,
) -> tuple[list[BaseMessage], list[int]]:
history_basemessages = [
translate_onyx_msg_to_langchain(msg) for msg in history if msg.token_count != 0
translate_onyx_msg_to_langchain(msg, exclude_images)
for msg in history
if msg.token_count != 0
]
history_token_counts = [msg.token_count for msg in history if msg.token_count != 0]
return history_basemessages, history_token_counts
52 changes: 36 additions & 16 deletions backend/onyx/llm/utils.py
@@ -142,6 +142,7 @@ def build_content_with_imgs(
img_urls: list[str] | None = None,
b64_imgs: list[str] | None = None,
message_type: MessageType = MessageType.USER,
exclude_images: bool = False,
) -> str | list[str | dict[str, Any]]: # matching Langchain's BaseMessage content type
files = files or []

@@ -157,7 +158,7 @@

message_main_content = _build_content(message, files)

if not img_files and not img_urls:
if exclude_images or (not img_files and not img_urls):
return message_main_content

return cast(
@@ -382,9 +383,19 @@ def _strip_colon_from_model_name(model_name: str) -> str:
return ":".join(model_name.split(":")[:-1]) if ":" in model_name else model_name


def _find_model_obj(
model_map: dict, provider: str, model_names: list[str | None]
) -> dict | None:
def _find_model_obj(model_map: dict, provider: str, model_name: str) -> dict | None:
stripped_model_name = _strip_extra_provider_from_model_name(model_name)

model_names = [
model_name,
# Remove leading extra provider. Usually for cases where the user has a
# custom model proxy which prepends another prefix
_strip_extra_provider_from_model_name(model_name),
# remove :XXXX from the end, if present. Needed for ollama.
_strip_colon_from_model_name(model_name),
_strip_colon_from_model_name(stripped_model_name),
]

# Filter out None values and deduplicate model names
filtered_model_names = [name for name in model_names if name]

@@ -417,21 +428,10 @@ def get_llm_max_tokens(
return GEN_AI_MAX_TOKENS

try:
extra_provider_stripped_model_name = _strip_extra_provider_from_model_name(
model_name
)
model_obj = _find_model_obj(
model_map,
model_provider,
[
model_name,
# Remove leading extra provider. Usually for cases where user has a
# customer model proxy which appends another prefix
extra_provider_stripped_model_name,
# remove :XXXX from the end, if present. Needed for ollama.
_strip_colon_from_model_name(model_name),
_strip_colon_from_model_name(extra_provider_stripped_model_name),
],
model_name,
)
if not model_obj:
raise RuntimeError(
@@ -523,3 +523,23 @@ def get_max_input_tokens(
raise RuntimeError("No tokens for input for the LLM given settings")

return input_toks


def model_supports_image_input(model_name: str, model_provider: str) -> bool:
model_map = get_model_map()
try:
model_obj = _find_model_obj(
model_map,
model_provider,
model_name,
)
if not model_obj:
raise RuntimeError(
f"No litellm entry found for {model_provider}/{model_name}"
)
return model_obj.get("supports_vision", False)
except Exception:
logger.exception(
f"Failed to get model object for {model_provider}/{model_name}"
)
return False
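For context, the lookup that `model_supports_image_input` relies on can be sketched like this (the candidate-name order mirrors `_find_model_obj` above; the helper name and map-key formats are assumptions for illustration):

```python
# Illustrative sketch: try progressively normalized model names against litellm's
# model map, then read the vision flag from the matching entry.
def resolve_model_entry(model_map: dict, provider: str, model_name: str) -> dict | None:
    stripped = model_name.split("/", 1)[-1]  # drop a leading "proxy/" style prefix
    candidates = [
        model_name,
        stripped,
        model_name.rsplit(":", 1)[0],  # drop an ":XXXX" tag (ollama)
        stripped.rsplit(":", 1)[0],
    ]
    for name in dict.fromkeys(c for c in candidates if c):  # dedupe, keep order
        for key in (f"{provider}/{name}", name):  # assumed key formats
            if key in model_map:
                return model_map[key]
    return None

entry = resolve_model_entry({"ollama/llava": {"supports_vision": True}}, "ollama", "llava:latest")
supports_vision = bool(entry and entry.get("supports_vision", False))
```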
@@ -16,6 +16,7 @@
from onyx.llm.models import PreviousMessage
from onyx.llm.utils import build_content_with_imgs
from onyx.llm.utils import message_to_string
from onyx.llm.utils import model_supports_image_input
from onyx.prompts.constants import GENERAL_SEP_PAT
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import ToolResponse
@@ -316,12 +317,22 @@ def build_next_prompt(
for img in img_generation_response
if img.image_data is not None
]
prompt_builder.update_user_prompt(
build_image_generation_user_prompt(
query=prompt_builder.get_user_message_content(),
img_urls=img_urls,
b64_imgs=b64_imgs,
)

user_prompt = build_image_generation_user_prompt(
query=prompt_builder.get_user_message_content(),
supports_image_input=model_supports_image_input(
prompt_builder.llm_config.model_name,
prompt_builder.llm_config.model_provider,
),
prompts=[
prompt
for response in img_generation_response
for prompt in response.revised_prompt
],
img_urls=img_urls,
b64_imgs=b64_imgs,
)

prompt_builder.update_user_prompt(user_prompt)

return prompt_builder
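The collection step above can be pictured with a self-contained sketch (`ImageResult` is a hypothetical stand-in for the tool's response objects): results are split into URLs, inline base64 images, and the revised prompts that serve as a text fallback for non-vision models.

```python
from dataclasses import dataclass

@dataclass
class ImageResult:
    url: str | None
    image_data: str | None  # base64-encoded image, if returned inline
    revised_prompt: str

def split_outputs(results: list[ImageResult]) -> tuple[list[str], list[str], list[str]]:
    # Separate what gets attached as images from what is used as a text fallback.
    urls = [r.url for r in results if r.url]
    b64s = [r.image_data for r in results if r.image_data is not None]
    prompts = [r.revised_prompt for r in results]
    return urls, b64s, prompts
```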
30 changes: 24 additions & 6 deletions backend/onyx/tools/tool_implementations/images/prompt.py
@@ -9,16 +9,34 @@
Can you please summarize them in a sentence or two? Do NOT include image urls or bulleted lists.
"""

IMG_GENERATION_SUMMARY_PROMPT_NO_IMAGES = """
You have generated images based on the following query: "{query}".
The prompts used to generate these images were: {prompts}

Describe what the generated images depict based on the query and prompts provided.
Summarize the key elements and content of the images in a sentence or two. Be specific
about what was generated rather than speculating about what the images 'likely' contain.
"""


def build_image_generation_user_prompt(
query: str,
supports_image_input: bool,
img_urls: list[str] | None = None,
b64_imgs: list[str] | None = None,
prompts: list[str] | None = None,
) -> HumanMessage:
return HumanMessage(
content=build_content_with_imgs(
message=IMG_GENERATION_SUMMARY_PROMPT.format(query=query).strip(),
b64_imgs=b64_imgs,
img_urls=img_urls,
if supports_image_input:
return HumanMessage(
content=build_content_with_imgs(
message=IMG_GENERATION_SUMMARY_PROMPT.format(query=query).strip(),
b64_imgs=b64_imgs,
img_urls=img_urls,
)
)
else:
return HumanMessage(
content=IMG_GENERATION_SUMMARY_PROMPT_NO_IMAGES.format(
query=query, prompts=prompts
).strip()
)
)
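A minimal usage sketch of the new branch (argument values are illustrative): a non-vision model gets the text-only fallback prompt built from the query and the revised generation prompts, with no image parts attached.

```python
# Usage sketch (illustrative values) for the text-only fallback path.
message = build_image_generation_user_prompt(
    query="a watercolor fox",
    supports_image_input=False,
    prompts=["A watercolor painting of a red fox in a misty forest"],
)
```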
27 changes: 3 additions & 24 deletions web/src/app/admin/assistants/AssistantEditor.tsx
@@ -443,26 +443,10 @@ export function AssistantEditor({
let enabledTools = Object.keys(values.enabled_tools_map)
.map((toolId) => Number(toolId))
.filter((toolId) => values.enabled_tools_map[toolId]);

const searchToolEnabled = searchTool
? enabledTools.includes(searchTool.id)
: false;
const imageGenerationToolEnabled = imageGenerationTool
? enabledTools.includes(imageGenerationTool.id)
: false;

if (imageGenerationToolEnabled) {
if (
// model must support image input for image generation
// to work
!checkLLMSupportsImageInput(
values.llm_model_version_override || defaultModelName || ""
)
) {
enabledTools = enabledTools.filter(
(toolId) => toolId !== imageGenerationTool!.id
);
}
}

// if disable_retrieval is set, set num_chunks to 0
// to tell the backend to not fetch any documents
@@ -913,25 +897,20 @@
id={`enabled_tools_map.${imageGenerationTool.id}`}
name={`enabled_tools_map.${imageGenerationTool.id}`}
onCheckedChange={() => {
if (
currentLLMSupportsImageOutput &&
isImageGenerationAvailable
) {
if (isImageGenerationAvailable) {
toggleToolInValues(
imageGenerationTool.id
);
}
}}
className={
!currentLLMSupportsImageOutput ||
!isImageGenerationAvailable
? "opacity-50 cursor-not-allowed"
: ""
}
/>
</TooltipTrigger>
{(!currentLLMSupportsImageOutput ||
!isImageGenerationAvailable) && (
{!isImageGenerationAvailable && (
<TooltipContent side="top" align="center">
<p className="bg-background-900 max-w-[200px] mb-1 text-sm rounded-lg p-1.5 text-white">
{!currentLLMSupportsImageOutput
17 changes: 16 additions & 1 deletion web/src/app/chat/ChatPage.tsx
@@ -49,6 +49,7 @@ import {
useContext,
useEffect,
useLayoutEffect,
useMemo,
useRef,
useState,
} from "react";
@@ -1627,7 +1628,7 @@ export function ChatPage({
setPopup({
type: "error",
message:
"The current Assistant does not support image input. Please select an assistant with Vision support.",
"The current model does not support image input. Please select a model with Vision support.",
});
return;
}
@@ -1845,6 +1846,14 @@
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [messageHistory]);

const imageFileInMessageHistory = useMemo(() => {
return messageHistory
.filter((message) => message.type === "user")
.some((message) =>
message.files.some((file) => file.type === ChatFileType.IMAGE)
);
}, [messageHistory]);

const currentVisibleRange = visibleRange.get(currentSessionId()) || {
start: 0,
end: 0,
@@ -1925,6 +1934,12 @@
handleSlackChatRedirect();
}, [searchParams, router]);

useEffect(() => {
if (imageFileInMessageHistory) {
llmOverrideManager.updateImageFilesPresent(true);
}
}, [imageFileInMessageHistory]);

useEffect(() => {
const handleKeyDown = (event: KeyboardEvent) => {
if (event.metaKey || event.ctrlKey) {
25 changes: 24 additions & 1 deletion web/src/app/chat/input/LLMPopover.tsx
@@ -5,7 +5,6 @@ import {
PopoverTrigger,
} from "@/components/ui/popover";
import { ChatInputOption } from "./ChatInputOption";
import { AnthropicSVG } from "@/components/icons/icons";
import { getDisplayNameForModel } from "@/lib/hooks";
import {
checkLLMSupportsImageInput,
@@ -19,6 +18,14 @@ import {
import { Persona } from "@/app/admin/assistants/interfaces";
import { LlmOverrideManager } from "@/lib/hooks";

import {
Tooltip,
TooltipContent,
TooltipProvider,
TooltipTrigger,
} from "@/components/ui/tooltip";
import { FiAlertTriangle } from "react-icons/fi";

interface LLMPopoverProps {
llmProviders: LLMProviderDescriptor[];
llmOverrideManager: LlmOverrideManager;
@@ -139,6 +146,22 @@ export default function LLMPopover({
);
}
})()}
{llmOverrideManager.imageFilesPresent &&
!checkLLMSupportsImageInput(name) && (
<TooltipProvider>
<Tooltip delayDuration={0}>
<TooltipTrigger className="my-auto flex items-center ml-auto">
<FiAlertTriangle className="text-alert" size={16} />
</TooltipTrigger>
<TooltipContent>
<p className="text-xs">
This LLM is not vision-capable and cannot process
image files present in your chat session.
</p>
</TooltipContent>
</Tooltip>
</TooltipProvider>
)}
</button>
);
}