From 423db484e5972a5abe03a02b4469193a1f546fc6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9-Abush=20Clause?= Date: Sun, 26 Jan 2025 23:31:44 +0100 Subject: [PATCH] feat(provider_engine): add support for uploading text and PDF files to Anthropic Fixes #485 --- basilisk/conversation/__init__.py | 14 +- .../{image_model.py => attached_file.py} | 159 +++++++++++------- basilisk/conversation/conversation_helper.py | 10 +- basilisk/conversation/conversation_model.py | 4 +- basilisk/gui/conversation_tab.py | 97 +++++++---- basilisk/gui/main_frame.py | 2 +- basilisk/provider_engine/anthropic_engine.py | 28 ++- basilisk/provider_engine/gemini_engine.py | 4 +- 8 files changed, 205 insertions(+), 113 deletions(-) rename basilisk/conversation/{image_model.py => attached_file.py} (82%) diff --git a/basilisk/conversation/__init__.py b/basilisk/conversation/__init__.py index 88073e25..0d1a93a4 100644 --- a/basilisk/conversation/__init__.py +++ b/basilisk/conversation/__init__.py @@ -1,3 +1,12 @@ +from .attached_file import ( + URL_PATTERN, + AttachmentFile, + AttachmentFileTypes, + ImageFile, + NotImageError, + get_mime_type, + parse_supported_attachment_formats, +) from .conversation_helper import PROMPT_TITLE from .conversation_model import ( Conversation, @@ -5,16 +14,19 @@ MessageBlock, MessageRoleEnum, ) -from .image_model import URL_PATTERN, ImageFile, ImageFileTypes, NotImageError __all__ = [ + "AttachmentFile", + "AttachmentFileTypes", "Conversation", + "get_mime_type", "ImageFile", "ImageFileTypes", "Message", "MessageBlock", "MessageRoleEnum", "NotImageError", + "parse_supported_attachment_formats", "PROMPT_TITLE", "URL_PATTERN", ] diff --git a/basilisk/conversation/image_model.py b/basilisk/conversation/attached_file.py similarity index 82% rename from basilisk/conversation/image_model.py rename to basilisk/conversation/attached_file.py index 82c0bf47..b16bd82f 100644 --- a/basilisk/conversation/image_model.py +++ b/basilisk/conversation/attached_file.py @@ -105,20 +105,25 @@ def parse_supported_attachment_formats( return wildcard -class ImageFileTypes(Enum): +def get_mime_type(path: str) -> str | None: + """ + Get the MIME type of a file. + """ + return mimetypes.guess_type(path)[0] + + +class AttachmentFileTypes(Enum): UNKNOWN = "unknown" - IMAGE_LOCAL = "local" - IMAGE_MEMORY = "memory" - IMAGE_URL = "http" + LOCAL = "local" + MEMORY = "memory" + URL = "http" @classmethod - def _missing_(cls, value: object) -> ImageFileTypes: - if isinstance(value, str) and value.lower() == "data": - return cls.IMAGE_URL - if isinstance(value, str) and value.lower() == "https": - return cls.IMAGE_URL + def _missing_(cls, value: object) -> AttachmentFileTypes: + if isinstance(value, str) and value.lower() in ("data", "https"): + return cls.URL if isinstance(value, str) and value.lower() == "zip": - return cls.IMAGE_LOCAL + return cls.LOCAL return cls.UNKNOWN @@ -126,11 +131,84 @@ class NotImageError(ValueError): pass -class ImageFile(BaseModel): +class AttachmentFile(BaseModel): location: PydanticUPath name: str | None = None description: str | None = None size: int | None = None + + def __init__(self, /, **data: Any) -> None: + super().__init__(**data) + if not self.name: + self.name = self._get_name() + self.size = self._get_size() + + @property + def type(self) -> AttachmentFileTypes: + return AttachmentFileTypes(self.location.protocol) + + def _get_name(self) -> str: + return self.location.name + + def _get_size(self) -> int | None: + if self.type == AttachmentFileTypes.URL: + return None + return self.location.stat().st_size + + @property + def display_size(self) -> str: + size = self.size + if size is None: + return _("Unknown") + if size < 1024: + return f"{size} B" + if size < 1024 * 1024: + return f"{size / 1024:.2f} KB" + return f"{size / 1024 / 1024:.2f} MB" + + @property + def send_location(self) -> UPath: + return self.location + + @property + def mime_type(self) -> str | None: + if self.type == AttachmentFileTypes.URL: + return None + mime_type, _ = mimetypes.guess_type(self.send_location) + return mime_type + + @property + def display_location(self): + location = str(self.location) + if location.startswith("data:"): + location = f"{location[:50]}...{location[-10:]}" + return location + + @staticmethod + def remove_location(location: UPath): + log.debug(f"Removing image at {location}") + try: + fs = location.fs + fs.rm(location.path) + except Exception as e: + log.error(f"Error deleting image at {location}: {e}") + + def read_as_str(self): + with self.location.open(mode="r") as file: + return file.read() + + def encode_base64(self) -> str: + with self.location.open(mode="rb") as file: + return base64.b64encode(file.read()).decode("utf-8") + + def __del__(self): + if self.type == AttachmentFileTypes.URL: + return + if self.type == AttachmentFileTypes.MEMORY: + self.remove_location(self.location) + + +class ImageFile(AttachmentFile): dimensions: tuple[int, int] | None = None resize_location: PydanticUPath | None = Field(default=None, exclude=True) @@ -150,7 +228,7 @@ def build_from_url(cls, url: str) -> ImageFile: dimensions = get_image_dimensions(BytesIO(r.content)) return cls( location=url, - type=ImageFileTypes.IMAGE_URL, + type=AttachmentFileTypes.URL, size=size, description=content_type, dimensions=dimensions, @@ -188,39 +266,17 @@ def validate_location( def __init__(self, /, **data: Any) -> None: super().__init__(**data) - if not self.name: - self.name = self._get_name() - self.size = self._get_size() if not self.dimensions: self.dimensions = self._get_dimensions() __init__.__pydantic_base_init__ = True @property - def type(self) -> ImageFileTypes: - return ImageFileTypes(self.location.protocol) - - def _get_name(self) -> str: - return self.location.name - - def _get_size(self) -> int | None: - if self.type == ImageFileTypes.IMAGE_URL: - return None - return self.location.stat().st_size - - @property - def display_size(self) -> str: - size = self.size - if size is None: - return _("Unknown") - if size < 1024: - return f"{size} B" - if size < 1024 * 1024: - return f"{size / 1024:.2f} KB" - return f"{size / 1024 / 1024:.2f} MB" + def send_location(self) -> UPath: + return self.resize_location or self.location def _get_dimensions(self) -> tuple[int, int] | None: - if self.type == ImageFileTypes.IMAGE_URL: + if self.type == AttachmentFileTypes.URL: return None with self.location.open(mode="rb") as image_file: return get_image_dimensions(image_file) @@ -235,7 +291,7 @@ def display_dimensions(self) -> str: def resize( self, conv_folder: UPath, max_width: int, max_height: int, quality: int ): - if ImageFileTypes.IMAGE_URL == self.type: + if AttachmentFileTypes.URL == self.type: return log.debug("Resizing image") resize_location = conv_folder.joinpath( @@ -253,10 +309,6 @@ def resize( ) self.resize_location = resize_location if success else None - @property - def send_location(self) -> UPath: - return self.resize_location or self.location - @measure_time def encode_image(self) -> str: if self.size and self.size > 1024 * 1024 * 1024: @@ -266,29 +318,15 @@ def encode_image(self) -> str: with self.send_location.open(mode="rb") as image_file: return base64.b64encode(image_file.read()).decode("utf-8") - @property - def mime_type(self) -> str | None: - if self.type == ImageFileTypes.IMAGE_URL: - return None - mime_type, _ = mimetypes.guess_type(self.send_location) - return mime_type - @property def url(self) -> str: - if not isinstance(self.type, ImageFileTypes): + if not isinstance(self.type, AttachmentFileTypes): raise ValueError("Invalid image type") - if self.type == ImageFileTypes.IMAGE_URL: + if self.type == AttachmentFileTypes.URL: return str(self.location) base64_image = self.encode_image() return f"data:{self.mime_type};base64,{base64_image}" - @property - def display_location(self): - location = str(self.location) - if location.startswith("data:image/"): - location = f"{location[:50]}...{location[-10:]}" - return location - @staticmethod def remove_location(location: UPath): log.debug(f"Removing image at {location}") @@ -299,9 +337,6 @@ def remove_location(location: UPath): log.error(f"Error deleting image at {location}: {e}") def __del__(self): - if self.type == ImageFileTypes.IMAGE_URL: - return if self.resize_location: self.remove_location(self.resize_location) - if self.type == ImageFileTypes.IMAGE_MEMORY: - self.remove_location(self.location) + super().__del__() diff --git a/basilisk/conversation/conversation_helper.py b/basilisk/conversation/conversation_helper.py index 175384cf..862b82ee 100644 --- a/basilisk/conversation/conversation_helper.py +++ b/basilisk/conversation/conversation_helper.py @@ -11,7 +11,7 @@ from basilisk.config import conf from basilisk.decorators import measure_time -from .image_model import ImageFile, ImageFileTypes +from .attached_file import AttachmentFile, AttachmentFileTypes, ImageFile if TYPE_CHECKING: from .conversation_model import Conversation @@ -23,11 +23,13 @@ def save_attachments( - attachments: list[ImageFile], attachment_path: str, fs: ZipFileSystem + attachments: list[AttachmentFile | ImageFile], + attachment_path: str, + fs: ZipFileSystem, ): attachment_mapping = {} for attachment in attachments: - if attachment.type == ImageFileTypes.IMAGE_URL: + if attachment.type == AttachmentFileTypes.URL: continue new_location = f"{attachment_path}/{attachment.location.name}" with attachment.location.open(mode="rb") as attachment_file: @@ -56,7 +58,7 @@ def create_conv_main_file(conversation: Conversation, fs: ZipFileSystem): def restore_attachments(attachments: list[ImageFile], storage_path: UPath): for attachment in attachments: - if attachment.type == ImageFileTypes.IMAGE_URL: + if attachment.type == AttachmentFileTypes.URL: continue new_path = storage_path / attachment.location.name with attachment.location.open(mode="rb") as attachment_file: diff --git a/basilisk/conversation/conversation_model.py b/basilisk/conversation/conversation_model.py index 77c3bcf2..ca7b0447 100644 --- a/basilisk/conversation/conversation_model.py +++ b/basilisk/conversation/conversation_model.py @@ -8,8 +8,8 @@ from basilisk.provider_ai_model import AIModelInfo +from .attached_file import AttachmentFile, ImageFile from .conversation_helper import create_bskc_file, open_bskc_file -from .image_model import ImageFile class MessageRoleEnum(Enum): @@ -21,7 +21,7 @@ class MessageRoleEnum(Enum): class Message(BaseModel): role: MessageRoleEnum content: str - attachments: list[ImageFile] | None = Field(default=None) + attachments: list[AttachmentFile | ImageFile] | None = Field(default=None) class MessageBlock(BaseModel): diff --git a/basilisk/gui/conversation_tab.py b/basilisk/gui/conversation_tab.py index 4122d252..1c414592 100644 --- a/basilisk/gui/conversation_tab.py +++ b/basilisk/gui/conversation_tab.py @@ -20,12 +20,14 @@ from basilisk.conversation import ( PROMPT_TITLE, URL_PATTERN, + AttachmentFile, Conversation, ImageFile, Message, MessageBlock, MessageRoleEnum, NotImageError, + get_mime_type, parse_supported_attachment_formats, ) from basilisk.decorators import ensure_no_task_running @@ -40,6 +42,7 @@ from .base_conversation import BaseConversation from .html_view_window import show_html_view_window +from .read_only_message_dialog import ReadOnlyMessageDialog from .search_dialog import SearchDialog, SearchDirection if TYPE_CHECKING: @@ -101,7 +104,7 @@ def __init__( self.bskc_path = bskc_path self.conv_storage_path = conv_storage_path or self.conv_storage_path() self.conversation = conversation or Conversation() - self.image_files: list[ImageFile] = [] + self.attachment_files: list[AttachmentFile | ImageFile] = [] self.last_time = 0 self.message_segment_manager = MessageSegmentManager() self.recording_thread: Optional[RecordingThread] = None @@ -177,12 +180,8 @@ def init_ui(self): ) self.attachments_list.InsertColumn(0, _("Name")) self.attachments_list.InsertColumn(1, _("Size")) - self.attachments_list.InsertColumn(2, _("Dimensions")) - self.attachments_list.InsertColumn(3, _("Path")) self.attachments_list.SetColumnWidth(0, 200) self.attachments_list.SetColumnWidth(1, 100) - self.attachments_list.SetColumnWidth(2, 100) - self.attachments_list.SetColumnWidth(3, 200) sizer.Add(self.attachments_list, proportion=0, flag=wx.ALL | wx.EXPAND) label = self.create_model_widget() sizer.Add(label, proportion=0, flag=wx.EXPAND) @@ -278,6 +277,10 @@ def on_attachments_context_menu(self, event: wx.ContextMenuEvent): menu = wx.Menu() if selected != wx.NOT_FOUND: + item = wx.MenuItem(menu, wx.ID_ANY, _("Show details") + " Enter") + menu.Append(item) + self.Bind(wx.EVT_MENU, self.on_show_attachment_details, item) + item = wx.MenuItem( menu, wx.ID_ANY, _("Remove selected image") + " (Shift+Del)" ) @@ -293,7 +296,7 @@ def on_attachments_context_menu(self, event: wx.ContextMenuEvent): menu, wx.ID_ANY, _("Paste (image or text)") + " (Ctrl+V)" ) menu.Append(item) - self.Bind(wx.EVT_MENU, self.on_image_paste, item) + self.Bind(wx.EVT_MENU, self.on_attachments_paste, item) item = wx.MenuItem(menu, wx.ID_ANY, _("Add image files...")) menu.Append(item) self.Bind(wx.EVT_MENU, self.add_attachments_dlg, item) @@ -311,19 +314,19 @@ def on_attachments_key_down(self, event: wx.KeyEvent): if modifiers == wx.MOD_CONTROL and key_code == ord("C"): self.on_copy_image_url(None) if modifiers == wx.MOD_CONTROL and key_code == ord("V"): - self.on_image_paste(None) + self.on_attachments_paste(None) if modifiers == wx.MOD_NONE and key_code == wx.WXK_DELETE: self.on_attachments_remove(None) event.Skip() - def on_image_paste(self, event: wx.CommandEvent): + def on_attachments_paste(self, event: wx.CommandEvent): with wx.TheClipboard as clipboard: if clipboard.IsSupported(wx.DataFormat(wx.DF_FILENAME)): log.debug("Pasting files from clipboard") file_data = wx.FileDataObject() clipboard.GetData(file_data) paths = file_data.GetFilenames() - self.add_attachment(paths) + self.add_attachments(paths) elif clipboard.IsSupported(wx.DataFormat(wx.DF_TEXT)): log.debug("Pasting text from clipboard") text_data = wx.TextDataObject() @@ -350,16 +353,16 @@ def on_image_paste(self, event: wx.CommandEvent): ) with path.open("wb") as f: img.SaveFile(f, wx.BITMAP_TYPE_PNG) - self.add_attachment([ImageFile(location=path)]) + self.add_attachments([ImageFile(location=path)]) else: log.info("Unsupported clipboard data") def add_attachments_dlg(self, event: wx.CommandEvent = None): - wilrdcard = parse_supported_attachment_formats( + wildcard = parse_supported_attachment_formats( self.current_engine.supported_attachment_formats ) - if not wilrdcard: + if not wildcard: wx.MessageBox( # Translators: This message is displayed when there are no supported attachment formats. _("This provider does not support any attachment formats."), @@ -367,17 +370,17 @@ def add_attachments_dlg(self, event: wx.CommandEvent = None): wx.OK | wx.ICON_ERROR, ) return - wilrdcard = _("All supported formats") + f" ({wilrdcard})|{wilrdcard}" + wildcard = _("All supported formats") + f" ({wildcard})|{wildcard}" file_dialog = wx.FileDialog( self, message=_("Select one or more files to attach"), style=wx.FD_OPEN | wx.FD_FILE_MUST_EXIST | wx.FD_MULTIPLE, - wildcard=wilrdcard, + wildcard=wildcard, ) if file_dialog.ShowModal() == wx.ID_OK: paths = file_dialog.GetPaths() - self.add_attachment(paths) + self.add_attachments(paths) file_dialog.Destroy() def add_image_url_dlg(self, event: wx.CommandEvent = None): @@ -442,7 +445,7 @@ def add_image_from_url(self, url: str): wx.OK | wx.ICON_ERROR, ) return - wx.CallAfter(self.add_attachment, [image_file]) + wx.CallAfter(self.add_attachments, [image_file]) self.task = None @@ -453,11 +456,33 @@ def add_image_url_thread(self, url: str): ) self.task.start() + def on_show_attachment_details(self, event: wx.CommandEvent): + selected = self.attachments_list.GetFirstSelected() + if selected == wx.NOT_FOUND: + return + image_file = self.attachment_files[selected] + details = { + _("Name"): image_file.name, + _("Size"): image_file.display_size, + _("Location"): image_file.location, + } + mime_type = image_file.mime_type + if mime_type: + details[_("MIME type")] = mime_type + if mime_type.startswith("image/"): + details[_("Dimensions")] = image_file.display_dimensions + details_str = "\n".join( + _("%s: %s") % (k, v) for k, v in details.items() + ) + ReadOnlyMessageDialog( + self, _("Attachment details"), details_str + ).ShowModal() + def on_attachments_remove(self, vent: wx.CommandEvent): selection = self.attachments_list.GetFirstSelected() if selection == wx.NOT_FOUND: return - self.image_files.pop(selection) + self.attachment_files.pop(selection) self.refresh_attachments_list() if selection >= self.attachments_list.GetItemCount(): selection -= 1 @@ -472,7 +497,7 @@ def on_copy_image_url(self, event: wx.CommandEvent): selected = self.attachments_list.GetFirstSelected() if selected == wx.NOT_FOUND: return - url = self.image_files[selected].location + url = self.attachment_files[selected].location with wx.TheClipboard as clipboard: clipboard.SetData(wx.TextDataObject(url)) @@ -495,7 +520,7 @@ def refresh_accounts(self): def refresh_attachments_list(self): self.attachments_list.DeleteAllItems() - if not self.image_files: + if not self.attachment_files: self.attachments_list_label.Hide() self.attachments_list.Hide() self.Layout() @@ -503,24 +528,26 @@ def refresh_attachments_list(self): self.attachments_list_label.Show() self.attachments_list.Show() self.Layout() - for i, image in enumerate(self.image_files): + for i, image in enumerate(self.attachment_files): self.attachments_list.InsertItem(i, image.name) self.attachments_list.SetItem(i, 1, image.display_size) - self.attachments_list.SetItem(i, 2, image.display_dimensions) - self.attachments_list.SetItem(i, 3, image.display_location) self.attachments_list.SetItemState( i, wx.LIST_STATE_FOCUSED, wx.LIST_STATE_FOCUSED ) self.attachments_list.EnsureVisible(i) - def add_attachment(self, paths: list[str | ImageFile]): + def add_attachments(self, paths: list[str | AttachmentFile | ImageFile]): log.debug(f"Adding images: {paths}") for path in paths: - if isinstance(path, ImageFile): - self.image_files.append(path) + if isinstance(path, (AttachmentFile, ImageFile)): + self.attachment_files.append(path) else: - file = ImageFile(location=path) - self.image_files.append(file) + mime_type = get_mime_type(path) + if mime_type.startswith("image/"): + file = ImageFile(location=path) + else: + file = AttachmentFile(location=path) + self.attachment_files.append(file) self.refresh_attachments_list() self.attachments_list.SetFocus() @@ -889,7 +916,7 @@ def on_prompt_key_down(self, event: wx.KeyEvent): event.Skip() def on_prompt_paste(self, event): - self.on_image_paste(event) + self.on_attachments_paste(event) def insert_previous_prompt(self, event: wx.CommandEvent = None): if self.conversation.messages: @@ -978,7 +1005,7 @@ def refresh_messages(self, need_clear: bool = True): if need_clear: self.messages.Clear() self.message_segment_manager.clear() - self.image_files.clear() + self.attachment_files.clear() self.refresh_attachments_list() for block in self.conversation.messages: self.display_new_block(block) @@ -1084,7 +1111,7 @@ def ensure_model_compatibility(self) -> ProviderAIModel | None: _("Please select a model"), _("Error"), wx.OK | wx.ICON_ERROR ) return None - if self.image_files and not model.vision: + if self.attachment_files and not model.vision: vision_models = ", ".join( [m.name or m.id for m in self.current_engine.models if m.vision] ) @@ -1109,7 +1136,7 @@ def get_new_message_block(self) -> MessageBlock | None: if not model: return None if config.conf().images.resize: - for image in self.image_files: + for image in self.attachment_files: image.resize( self.conv_storage_path, config.conf().images.max_width, @@ -1120,7 +1147,7 @@ def get_new_message_block(self) -> MessageBlock | None: request=Message( role=MessageRoleEnum.USER, content=self.prompt.GetValue(), - attachments=self.image_files, + attachments=self.attachment_files, ), model_id=model.id, provider_id=self.current_account.provider.id, @@ -1146,7 +1173,7 @@ def get_completion_args(self) -> dict[str, Any] | None: def on_submit(self, event: wx.CommandEvent): if not self.submit_btn.IsEnabled(): return - if not self.prompt.GetValue() and not self.image_files: + if not self.prompt.GetValue() and not self.attachment_files: self.prompt.SetFocus() return completion_kw = self.get_completion_args() @@ -1206,7 +1233,7 @@ def _pre_handle_completion_with_stream(self, new_block: MessageBlock): self.display_new_block(new_block) self.messages.SetInsertionPointEnd() self.prompt.Clear() - self.image_files.clear() + self.attachment_files.clear() self.refresh_attachments_list() def _handle_completion_with_stream(self, chunk: str): @@ -1315,7 +1342,7 @@ def _post_completion_without_stream(self, new_block: MessageBlock): self.display_new_block(new_block) self._handle_accessible_output(new_block.response.content) self.prompt.Clear() - self.image_files.clear() + self.attachment_files.clear() self.refresh_attachments_list() if config.conf().conversation.focus_history_after_send: self.messages.SetFocus() diff --git a/basilisk/gui/main_frame.py b/basilisk/gui/main_frame.py index 95cb8b5b..f0e376a3 100644 --- a/basilisk/gui/main_frame.py +++ b/basilisk/gui/main_frame.py @@ -298,7 +298,7 @@ def screen_capture( def post_screen_capture(self, imagefile: ImageFile | str): log.debug("Screen capture received") - self.current_tab.add_attachment([imagefile]) + self.current_tab.add_attachments([imagefile]) if not self.IsShown(): self.Show() self.Restore() diff --git a/basilisk/provider_engine/anthropic_engine.py b/basilisk/provider_engine/anthropic_engine.py index 155050f8..46d7592c 100644 --- a/basilisk/provider_engine/anthropic_engine.py +++ b/basilisk/provider_engine/anthropic_engine.py @@ -7,11 +7,13 @@ from anthropic import Anthropic from anthropic.types import Message as AnthropicMessage from anthropic.types import TextBlock +from anthropic.types.document_block_param import DocumentBlockParam from anthropic.types.image_block_param import ImageBlockParam, Source +from anthropic.types.text_block_param import TextBlockParam from basilisk.conversation import ( + AttachmentFileTypes, Conversation, - ImageFileTypes, Message, MessageBlock, MessageRoleEnum, @@ -168,15 +170,29 @@ def convert_message(self, message: Message) -> dict: contents = [TextBlock(text=message.content, type="text")] if message.attachments: for attachment in message.attachments: - if attachment.type != ImageFileTypes.IMAGE_URL: + mime_type = attachment.mime_type + if attachment.type != AttachmentFileTypes.URL: source = Source( - data=attachment.encode_image(), + data=None, media_type=attachment.mime_type, type="base64", ) - contents.append( - ImageBlockParam(source=source, type="image") - ) + if mime_type.startswith("image/"): + source["data"] = attachment.encode_image() + contents.append( + ImageBlockParam(source=source, type="image") + ) + elif mime_type.startswith("application/"): + source["data"] = attachment.encode_base64() + contents.append( + DocumentBlockParam(source=source, type="document") + ) + elif mime_type in ("text/csv", "text/plain"): + source["data"] = attachment.read_as_str() + source["type"] = "text" + contents.append( + TextBlockParam(source=source, type="document") + ) return {"role": message.role.value, "content": contents} prepare_message_request = convert_message diff --git a/basilisk/provider_engine/gemini_engine.py b/basilisk/provider_engine/gemini_engine.py index 8e92d8af..c085d7ed 100644 --- a/basilisk/provider_engine/gemini_engine.py +++ b/basilisk/provider_engine/gemini_engine.py @@ -7,9 +7,9 @@ import google.generativeai as genai from basilisk.conversation import ( + AttachmentFileTypes, Conversation, ImageFile, - ImageFileTypes, Message, MessageBlock, MessageRoleEnum, @@ -122,7 +122,7 @@ def convert_role(self, role: MessageRoleEnum) -> str: ) def convert_image(self, image: ImageFile) -> genai.protos.Part: - if image.type == ImageFileTypes.IMAGE_URL: + if image.type == AttachmentFileTypes.URL: raise NotImplementedError("Image URL not supported") with image.send_location.open("rb") as f: blob = genai.protos.Blob(mime_type=image.mime_type, data=f.read())