Allow all LLMs for image generation assistants
pablonyx committed Jan 21, 2025
1 parent cc4953b commit 12217f5
Showing 8 changed files with 127 additions and 59 deletions.
15 changes: 13 additions & 2 deletions backend/onyx/chat/prompt_builder/answer_prompt_builder.py
@@ -15,6 +15,7 @@
from onyx.llm.utils import build_content_with_imgs
from onyx.llm.utils import check_message_tokens
from onyx.llm.utils import message_to_prompt_and_imgs
from onyx.llm.utils import model_supports_image_input
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
@@ -86,6 +87,7 @@ def __init__(
provider_type=llm_config.model_provider,
model_name=llm_config.model_name,
)
self.llm_config = llm_config
self.llm_tokenizer_encode_func = cast(
Callable[[str], list[int]], llm_tokenizer.encode
)
@@ -94,12 +96,21 @@
(
self.message_history,
self.history_token_cnts,
) = translate_history_to_basemessages(message_history)
) = translate_history_to_basemessages(
message_history,
exclude_images=not model_supports_image_input(
self.llm_config.model_name,
self.llm_config.model_provider,
),
)

self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
self.user_message_and_token_cnt = (
user_message,
check_message_tokens(user_message, self.llm_tokenizer_encode_func),
check_message_tokens(
user_message,
self.llm_tokenizer_encode_func,
),
)

self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []
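In short, the prompt builder now asks whether the configured model can accept image input and, if not, strips images from the chat history before it is translated into LangChain messages. A minimal sketch of that gating, with the onyx helpers stubbed out (the lookup table below is hypothetical, not litellm's real data):

def supports_image_input(model_name: str, model_provider: str) -> bool:
    # Stand-in for onyx.llm.utils.model_supports_image_input, which consults
    # litellm's model map for a "supports_vision" flag.
    vision_models = {("openai", "gpt-4o")}  # hypothetical entries
    return (model_provider, model_name) in vision_models


def build_history(message_history: list[dict], model_name: str, model_provider: str) -> list[dict]:
    # Mirrors the constructor change above: images are excluded from history
    # unless the selected model can actually read them.
    exclude_images = not supports_image_input(model_name, model_provider)
    return [
        {"text": msg["text"], "images": [] if exclude_images else msg["images"]}
        for msg in message_history
    ]


history = [{"text": "What is in this picture?", "images": ["<b64 png>"]}]
print(build_history(history, "gpt-3.5-turbo", "openai"))  # images dropped
print(build_history(history, "gpt-4o", "openai"))         # images kept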
10 changes: 8 additions & 2 deletions backend/onyx/chat/prompt_builder/utils.py
@@ -11,14 +11,17 @@

def translate_onyx_msg_to_langchain(
msg: ChatMessage | PreviousMessage,
exclude_images: bool = False,
) -> BaseMessage:
files: list[InMemoryChatFile] = []

# If the message is a `ChatMessage`, it doesn't have the downloaded files
# attached. Just ignore them for now.
if not isinstance(msg, ChatMessage):
files = msg.files
content = build_content_with_imgs(msg.message, files, message_type=msg.message_type)
content = build_content_with_imgs(
msg.message, files, message_type=msg.message_type, exclude_images=exclude_images
)

if msg.message_type == MessageType.SYSTEM:
raise ValueError("System messages are not currently part of history")
@@ -32,9 +35,12 @@ def translate_onyx_msg_to_langchain(

def translate_history_to_basemessages(
history: list[ChatMessage] | list["PreviousMessage"],
exclude_images: bool = False,
) -> tuple[list[BaseMessage], list[int]]:
history_basemessages = [
translate_onyx_msg_to_langchain(msg) for msg in history if msg.token_count != 0
translate_onyx_msg_to_langchain(msg, exclude_images)
for msg in history
if msg.token_count != 0
]
history_token_counts = [msg.token_count for msg in history if msg.token_count != 0]
return history_basemessages, history_token_counts
52 changes: 36 additions & 16 deletions backend/onyx/llm/utils.py
@@ -142,6 +142,7 @@ def build_content_with_imgs(
img_urls: list[str] | None = None,
b64_imgs: list[str] | None = None,
message_type: MessageType = MessageType.USER,
exclude_images: bool = False,
) -> str | list[str | dict[str, Any]]: # matching Langchain's BaseMessage content type
files = files or []

@@ -157,7 +158,7 @@

message_main_content = _build_content(message, files)

if not img_files and not img_urls:
if exclude_images or (not img_files and not img_urls):
return message_main_content

return cast(
@@ -382,9 +383,19 @@ def _strip_colon_from_model_name(model_name: str) -> str:
return ":".join(model_name.split(":")[:-1]) if ":" in model_name else model_name


def _find_model_obj(
model_map: dict, provider: str, model_names: list[str | None]
) -> dict | None:
def _find_model_obj(model_map: dict, provider: str, model_name: str) -> dict | None:
stripped_model_name = _strip_extra_provider_from_model_name(model_name)

model_names = [
model_name,
_strip_extra_provider_from_model_name(model_name),
# Remove leading extra provider. Usually for cases where the user has a
# custom model proxy which prepends another prefix
# remove :XXXX from the end, if present. Needed for ollama.
_strip_colon_from_model_name(model_name),
_strip_colon_from_model_name(stripped_model_name),
]

# Filter out None values and deduplicate model names
filtered_model_names = [name for name in model_names if name]

@@ -417,21 +428,10 @@ def get_llm_max_tokens(
return GEN_AI_MAX_TOKENS

try:
extra_provider_stripped_model_name = _strip_extra_provider_from_model_name(
model_name
)
model_obj = _find_model_obj(
model_map,
model_provider,
[
model_name,
# Remove leading extra provider. Usually for cases where user has a
# customer model proxy which appends another prefix
extra_provider_stripped_model_name,
# remove :XXXX from the end, if present. Needed for ollama.
_strip_colon_from_model_name(model_name),
_strip_colon_from_model_name(extra_provider_stripped_model_name),
],
model_name,
)
if not model_obj:
raise RuntimeError(
@@ -523,3 +523,23 @@ def get_max_input_tokens(
raise RuntimeError("No tokens for input for the LLM given settings")

return input_toks


def model_supports_image_input(model_name: str, model_provider: str) -> bool:
model_map = get_model_map()
try:
model_obj = _find_model_obj(
model_map,
model_provider,
model_name,
)
if not model_obj:
raise RuntimeError(
f"No litellm entry found for {model_provider}/{model_name}"
)
return model_obj.get("supports_vision", False)
except Exception:
logger.exception(
f"Failed to get model object for {model_provider}/{model_name}"
)
return False
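The new helper in llm/utils.py is deliberately fail-safe: any lookup error is logged and treated as "no vision support", so an unknown or misconfigured model degrades to the text-only path instead of failing the request. A small sketch of that pattern with hypothetical map entries:

import logging

logger = logging.getLogger(__name__)


def supports_image_input(model_map: dict, model_provider: str, model_name: str) -> bool:
    # Mirrors the fail-safe above: lookup problems are logged and treated as
    # "no vision support" instead of propagating an exception.
    try:
        model_obj = model_map[f"{model_provider}/{model_name}"]  # simplified lookup
        return model_obj.get("supports_vision", False)
    except Exception:
        logger.exception("Failed to get model object for %s/%s", model_provider, model_name)
        return False


# Hypothetical entries; real values come from litellm's registry.
model_map = {"openai/gpt-4o": {"supports_vision": True}}
print(supports_image_input(model_map, "openai", "gpt-4o"))         # True
print(supports_image_input(model_map, "openai", "gpt-3.5-turbo"))  # False (missing entry)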
@@ -16,6 +16,7 @@
from onyx.llm.models import PreviousMessage
from onyx.llm.utils import build_content_with_imgs
from onyx.llm.utils import message_to_string
from onyx.llm.utils import model_supports_image_input
from onyx.prompts.constants import GENERAL_SEP_PAT
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import ToolResponse
@@ -316,12 +317,22 @@ def build_next_prompt(
for img in img_generation_response
if img.image_data is not None
]
prompt_builder.update_user_prompt(
build_image_generation_user_prompt(
query=prompt_builder.get_user_message_content(),
img_urls=img_urls,
b64_imgs=b64_imgs,
)

user_prompt = build_image_generation_user_prompt(
query=prompt_builder.get_user_message_content(),
supports_image_input=model_supports_image_input(
prompt_builder.llm_config.model_name,
prompt_builder.llm_config.model_provider,
),
prompts=[
prompt
for response in img_generation_response
for prompt in response.revised_prompt
],
img_urls=img_urls,
b64_imgs=b64_imgs,
)

prompt_builder.update_user_prompt(user_prompt)

return prompt_builder
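With this change, a summarising LLM that cannot read images is given the revised generation prompts instead of the images themselves, so it can still describe what was produced. A rough sketch of that branch with a hypothetical response shape (the real tool passes the result through build_image_generation_user_prompt):

from dataclasses import dataclass


@dataclass
class GeneratedImage:
    # Hypothetical stand-in for the tool's image generation response model.
    revised_prompt: str
    image_data: str | None


def build_summary_request(query: str, images: list[GeneratedImage], supports_vision: bool) -> str:
    usable = [img for img in images if img.image_data is not None]
    if supports_vision:
        # Vision-capable model: the actual images would be attached here (elided).
        return f"Summarize the {len(usable)} generated image(s) for: {query}"
    # Text-only model: describe the images via the prompts that produced them.
    prompts = [img.revised_prompt for img in usable]
    return f"Summarize images generated for: {query!r} using prompts: {prompts}"


imgs = [GeneratedImage("a watercolor fox in a misty forest", "<b64 png>")]
print(build_summary_request("paint a fox", imgs, supports_vision=False))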
30 changes: 24 additions & 6 deletions backend/onyx/tools/tool_implementations/images/prompt.py
@@ -9,16 +9,34 @@
Can you please summarize them in a sentence or two? Do NOT include image urls or bulleted lists.
"""

IMG_GENERATION_SUMMARY_PROMPT_NO_IMAGES = """
You have generated images based on the following query: "{query}".
The prompts used to generate these images were: {prompts}
Describe what the generated images depict based on the query and prompts provided.
Summarize the key elements and content of the images in a sentence or two. Be specific
about what was generated rather than speculating about what the images 'likely' contain.
"""


def build_image_generation_user_prompt(
query: str,
supports_image_input: bool,
img_urls: list[str] | None = None,
b64_imgs: list[str] | None = None,
prompts: list[str] | None = None,
) -> HumanMessage:
return HumanMessage(
content=build_content_with_imgs(
message=IMG_GENERATION_SUMMARY_PROMPT.format(query=query).strip(),
b64_imgs=b64_imgs,
img_urls=img_urls,
if supports_image_input:
return HumanMessage(
content=build_content_with_imgs(
message=IMG_GENERATION_SUMMARY_PROMPT.format(query=query).strip(),
b64_imgs=b64_imgs,
img_urls=img_urls,
)
)
else:
return HumanMessage(
content=IMG_GENERATION_SUMMARY_PROMPT_NO_IMAGES.format(
query=query, prompts=prompts
).strip()
)
)
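A quick usage sketch of the new no-images branch; the values are hypothetical, and note that str.format renders the prompts list with its Python repr (brackets and quotes included), matching how the code above passes the list straight through:

NO_IMAGES_PROMPT = """
You have generated images based on the following query: "{query}".
The prompts used to generate these images were: {prompts}
Describe what the generated images depict based on the query and prompts provided.
"""

rendered = NO_IMAGES_PROMPT.format(
    query="paint a fox",
    prompts=["a watercolor fox in a misty forest"],
).strip()
print(rendered)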
27 changes: 3 additions & 24 deletions web/src/app/admin/assistants/AssistantEditor.tsx
@@ -443,26 +443,10 @@ export function AssistantEditor({
let enabledTools = Object.keys(values.enabled_tools_map)
.map((toolId) => Number(toolId))
.filter((toolId) => values.enabled_tools_map[toolId]);

const searchToolEnabled = searchTool
? enabledTools.includes(searchTool.id)
: false;
const imageGenerationToolEnabled = imageGenerationTool
? enabledTools.includes(imageGenerationTool.id)
: false;

if (imageGenerationToolEnabled) {
if (
// model must support image input for image generation
// to work
!checkLLMSupportsImageInput(
values.llm_model_version_override || defaultModelName || ""
)
) {
enabledTools = enabledTools.filter(
(toolId) => toolId !== imageGenerationTool!.id
);
}
}

// if disable_retrieval is set, set num_chunks to 0
// to tell the backend to not fetch any documents
@@ -913,25 +897,20 @@
id={`enabled_tools_map.${imageGenerationTool.id}`}
name={`enabled_tools_map.${imageGenerationTool.id}`}
onCheckedChange={() => {
if (
currentLLMSupportsImageOutput &&
isImageGenerationAvailable
) {
if (isImageGenerationAvailable) {
toggleToolInValues(
imageGenerationTool.id
);
}
}}
className={
!currentLLMSupportsImageOutput ||
!isImageGenerationAvailable
? "opacity-50 cursor-not-allowed"
: ""
}
/>
</TooltipTrigger>
{(!currentLLMSupportsImageOutput ||
!isImageGenerationAvailable) && (
{!isImageGenerationAvailable && (
<TooltipContent side="top" align="center">
<p className="bg-background-900 max-w-[200px] mb-1 text-sm rounded-lg p-1.5 text-white">
{!currentLLMSupportsImageOutput
22 changes: 22 additions & 0 deletions web/src/app/chat/ChatPage.tsx
@@ -49,6 +49,7 @@ import {
useContext,
useEffect,
useLayoutEffect,
useMemo,
useRef,
useState,
} from "react";
@@ -1845,6 +1846,14 @@ export function ChatPage({
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [messageHistory]);

const imageFileInMessageHistory = useMemo(() => {
return messageHistory
.filter((message) => message.type === "user")
.some((message) =>
message.files.some((file) => file.type === ChatFileType.IMAGE)
);
}, [messageHistory]);

const currentVisibleRange = visibleRange.get(currentSessionId()) || {
start: 0,
end: 0,
@@ -1925,6 +1934,19 @@
handleSlackChatRedirect();
}, [searchParams, router]);

useEffect(() => {
if (
imageFileInMessageHistory &&
!checkLLMSupportsImageInput(llmOverrideManager.llmOverride.modelName)
) {
setPopup({
message:
"This LLM will not be able to process all files (i.e. image files) in your chat history",
type: "error",
});
}
}, [llmOverrideManager.llmOverride, imageFileInMessageHistory]);

useEffect(() => {
const handleKeyDown = (event: KeyboardEvent) => {
if (event.metaKey || event.ctrlKey) {
7 changes: 4 additions & 3 deletions web/src/components/admin/connectors/Popup.tsx
@@ -1,7 +1,8 @@
import { useRef, useState } from "react";
import { cva, type VariantProps } from "class-variance-authority";
import { cn } from "@/lib/utils";
import { CheckCircle, XCircle } from "lucide-react";
import { Check, CheckCircle, XCircle } from "lucide-react";
import { Warning } from "@phosphor-icons/react";
const popupVariants = cva(
"fixed bottom-4 left-4 p-4 rounded-lg shadow-xl text-white z-[10000] flex items-center space-x-3 transition-all duration-300 ease-in-out",
{
@@ -26,9 +27,9 @@ export interface PopupSpec extends VariantProps<typeof popupVariants> {
export const Popup: React.FC<PopupSpec> = ({ message, type }) => (
<div className={cn(popupVariants({ type }))}>
{type === "success" ? (
<CheckCircle className="w-6 h-6 animate-pulse" />
<Check className="w-6 h-6" />
) : type === "error" ? (
<XCircle className="w-6 h-6 animate-pulse" />
<Warning className="w-6 h-6 " />
) : type === "info" ? (
<svg
className="w-6 h-6"
